// hyper_scripter/fuzzy.rs

use crate::error::Result;
use crate::state::State;
use futures::future::join_all;
use fuzzy_matcher::{
    skim::{SkimMatcherV2, SkimScoreConfig},
    FuzzyMatcher,
};
use std::borrow::Cow;
use std::cell::UnsafeCell;
use std::cmp::{Ordering, PartialOrd};
use std::sync::Arc;
use tokio::task::spawn_blocking;

/// Length-normalized score (raw score * 100 / key length in chars) at or below which a lone
/// match is reported as `FuzzResult::Low` instead of `FuzzResult::High`.
// TODO: pick this magic number more carefully (好好決定這個魔法數字).
const MID_SCORE: i64 = 800;
/// Score given to a candidate that equals the pattern exactly ("完全相符").
/// NOTE(review): presumably kept far below `i64::MAX` so that the `* 100` in the length
/// normalization cannot overflow — confirm.
const EXACXT_SCORE: i64 = i64::MAX / 1000;

/// Whether `score` is close enough to `best_score` to count as a competing match.
///
/// A shortfall of 0 or 1 is tolerated. This absorbs the +1 bonus that `my_fuzz` gives a
/// candidate matching in its original (un-reordered) segment order.
fn is_multifuzz(score: i64, best_score: i64) -> bool {
    let shortfall = best_score - score;
    shortfall <= 1
}

20#[derive(Debug)]
21pub enum FuzzResult<T> {
22    High(T),
23    Low(T),
24    Multi {
25        ans: T,
26        others: Vec<T>,
27        still_others: Vec<T>,
28    },
29}
30pub use FuzzResult::*;
31impl<T> FuzzResult<T> {
32    fn new_single(ans: T, score: FuzzScore) -> Self {
33        let score = score.score * 100 / score.len as i64;
34        match score {
35            0..=MID_SCORE => Low(ans),
36            _ => High(ans),
37        }
38    }
39    fn new_multi(ans: T, others: Vec<T>, still_others: Vec<T>) -> Self {
40        Multi {
41            ans,
42            others,
43            still_others,
44        }
45    }
46    pub fn get_ans(self) -> T {
47        match self {
48            High(t) => t,
49            Low(t) => t,
50            Multi { ans, .. } => ans,
51        }
52    }
53}
54
55static MATCHER: State<SkimMatcherV2> = State::new();
56
/// Anything that can present a string key to be fuzzy-matched.
pub trait FuzzKey {
    /// The text that the matcher scores against the user's pattern.
    fn fuzz_key(&self) -> Cow<'_, str>;
}
/// Every string-like type uses its own contents as the key, borrowed without allocating.
impl<T: AsRef<str>> FuzzKey for T {
    fn fuzz_key(&self) -> Cow<'_, str> {
        let key: &str = self.as_ref();
        Cow::from(key)
    }
}

/// `Copy` wrapper around a raw pointer, so that the pointer can be moved into a closure that
/// must be `Send` (raw pointers themselves are not `Send`).
///
/// Whoever dereferences the wrapped pointer is responsible for the pointee still being alive
/// and not aliased mutably at that moment.
#[derive(Copy, Clone)]
struct MyRaw<T>(T);
// SAFETY: moving a pointer *value* to another thread is harmless by itself; every dereference
// goes through an `unsafe` operation whose caller must guarantee the pointee is still valid.
// The impls are limited to raw-pointer payloads. The previous blanket `impl<T> Send` would
// also have made arbitrary non-`Send` values (e.g. `Rc`) sendable.
unsafe impl<T: ?Sized> Send for MyRaw<*const T> {}
unsafe impl<T: ?Sized> Send for MyRaw<*mut T> {}
impl MyRaw<*const str> {
    /// Erases the lifetime of `s` by storing it as a raw pointer.
    fn new(s: &str) -> MyRaw<*const str> {
        MyRaw(s as *const str)
    }
    /// Turns the pointer back into a string slice.
    ///
    /// # Safety
    /// The string this wrapper was created from must still be alive.
    unsafe fn as_str(&self) -> &str {
        &*self.0
    }
}
impl<T: Copy> MyRaw<T> {
    /// Returns a copy of the wrapped value.
    fn get(&self) -> T {
        self.0
    }
}

83enum MyCow {
84    Borrowed(MyRaw<*const str>),
85    Owned(String),
86}
87impl MyCow {
88    fn new(s: Cow<'_, str>) -> Self {
89        match s {
90            Cow::Borrowed(s) => MyCow::Borrowed(MyRaw::new(s)),
91            Cow::Owned(s) => MyCow::Owned(s),
92        }
93    }
94    unsafe fn get(&self) -> &str {
95        match self {
96            MyCow::Borrowed(s) => s.as_str(),
97            MyCow::Owned(s) => &*s,
98        }
99    }
100}
101
/// Score of one candidate: the matcher's raw score plus the candidate key's length in chars.
///
/// `len == 0` marks "no match" (the `Default` value). Recorded matches always have `len > 0`.
#[derive(Default, PartialEq, Eq, Debug, Clone, Copy)]
pub struct FuzzScore {
    len: usize,
    score: i64,
}
impl FuzzScore {
    /// `true` for the "no match" placeholder produced by `Default`.
    fn is_default(&self) -> bool {
        self.len == 0
    }
}
impl PartialOrd for FuzzScore {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// A higher raw score wins. The exception is two scores that differ by less than 2 and have
/// different lengths: those are ranked by length instead, and the shorter key wins. This
/// absorbs the +1 "original order" bonus ("吃掉「正常排序就命中」的差異").
///
/// NOTE(review): this rule is not transitive (e.g. (10, len 1) > (11, len 5) > (12, len 9),
/// yet (10, len 1) < (12, len 9)), so results of `max`/sorting can depend on input order.
impl Ord for FuzzScore {
    fn cmp(&self, other: &Self) -> Ordering {
        let near_tie = (self.score - other.score).abs() < 2;
        let same_len = self.len == other.len;
        match (near_tie, same_len) {
            (true, false) => other.len.cmp(&self.len),
            _ => self.score.cmp(&other.score),
        }
    }
}
129pub async fn fuzz<'a, T: FuzzKey + Send + 'a>(
130    name: &str,
131    iter: impl Iterator<Item = T>,
132    sep: &str,
133) -> Result<Option<FuzzResult<T>>> {
134    fuzz_with_multifuzz_ratio(name, iter, sep, None).await
135}
136
137/// multifuzz_ratio 為百分比的數字部份
138pub async fn fuzz_with_multifuzz_ratio<'a, T: FuzzKey + Send + 'a>(
139    name: &str,
140    iter: impl Iterator<Item = T>,
141    sep: &str,
142    multifuzz_ratio: Option<i64>,
143) -> Result<Option<FuzzResult<T>>> {
144    let raw_name = MyRaw::new(name);
145    let mut data_vec: Vec<_> = iter
146        .map(|t| (UnsafeCell::new(FuzzScore::default()), t))
147        .collect();
148    let sep = MyRaw::new(sep);
149    let has_ratio = multifuzz_ratio.is_some();
150
151    crate::set_once!(MATCHER, || {
152        let mut conf = SkimScoreConfig::default();
153        conf.bonus_consecutive *= 4;
154        SkimMatcherV2::default().score_config(conf)
155    });
156
157    let score_fut = data_vec.iter_mut().map(|(score, data)| {
158        let key = MyCow::new(data.fuzz_key());
159        let score_ptr = MyRaw(score.get());
160        spawn_blocking(move || {
161            // SAFTY: 等等就會 join,故這個函式 await 完之前都不可能釋放這些字串
162            let key = unsafe { key.get() };
163            let score = my_fuzz(
164                key,
165                unsafe { raw_name.as_str() },
166                unsafe { sep.as_str() },
167                !has_ratio,
168            );
169
170            if let Some(score) = score {
171                let len = key.chars().count();
172                // SAFETY: 怎麼可能有多個人持有同個元素的分數
173                assert_ne!(len, 0);
174                unsafe {
175                    *score_ptr.get() = FuzzScore { score, len };
176                }
177            }
178        })
179    });
180
181    join_all(score_fut).await;
182    // NOTE: 算分數就別平行做了,不然要搞原子性,可能得不償失
183    let best_score = data_vec
184        .iter_mut()
185        .map(|(score, _)| *score.get_mut())
186        .max()
187        .unwrap_or_default();
188
189    if best_score.is_default() {
190        log::info!("模糊搜沒搜到東西 {}", name);
191        return Ok(None);
192    }
193
194    let best_score_normalized =
195        multifuzz_ratio.map_or(best_score.score, |r| best_score.score * r / 100);
196
197    let mut ans = None;
198    let mut multifuzz_vec = vec![];
199    let mut secondary_multifuzz_vec = vec![];
200    for (score, data) in data_vec.into_iter() {
201        let score = score.into_inner();
202        if score == best_score && ans.is_none() {
203            ans = Some(data);
204        } else if is_multifuzz(score.score, best_score.score) {
205            log::warn!("找到一個分數相近者:{} {:?}", data.fuzz_key(), score);
206            multifuzz_vec.push(data);
207        } else if is_multifuzz(score.score, best_score_normalized) {
208            log::warn!("找到一個分數稍微相近者:{} {:?}", data.fuzz_key(), score);
209            secondary_multifuzz_vec.push(data);
210        }
211    }
212
213    let ans = ans.unwrap();
214    if multifuzz_vec.is_empty() && secondary_multifuzz_vec.is_empty() {
215        log::info!("模糊搜到一個東西 {}", ans.fuzz_key());
216        Ok(Some(FuzzResult::new_single(ans, best_score)))
217    } else {
218        log::warn!(
219            "模糊搜到太多東西,主要為結果為 {} {:?}",
220            ans.fuzz_key(),
221            best_score
222        );
223        Ok(Some(FuzzResult::new_multi(
224            ans,
225            multifuzz_vec,
226            secondary_multifuzz_vec,
227        )))
228    }
229}
230
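// TODO: 把這些 sep: &str 換成標準庫的 Pattern
// (Replace these `sep: &str` parameters with the standard library's `Pattern` trait once
// `std::str::pattern::Pattern` is stable.)
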
233pub fn is_prefix(prefix: &str, target: &str, sep: &str) -> bool {
234    if prefix.len() > target.len() {
235        return false;
236    }
237
238    let mut found = false;
239    foreach_reorder(target, sep, &mut |t| {
240        foreach_reorder(prefix, sep, &mut |p| {
241            if t.starts_with(p) {
242                found = true;
243            }
244            found
245        });
246        found
247    });
248
249    found
250}
251
252fn my_fuzz(mut choice: &str, pattern: &str, sep: &str, boost_exact: bool) -> Option<i64> {
253    if boost_exact && choice == pattern {
254        return Some(EXACXT_SCORE);
255    }
256
257    if choice.chars().next() == Some('.') && pattern.chars().next() != Some('.') {
258        log::trace!("拔掉匿名腳本的 `.` 前綴");
259        choice = &choice[1..];
260    }
261
262    let mut ans_opt = None;
263    let mut first = true;
264    foreach_reorder(choice, sep, &mut |choice_reordered| {
265        let score_opt = MATCHER.get().fuzzy_match(choice_reordered, pattern);
266        log::trace!(
267            "模糊搜尋,候選者:{},重排列成:{},輸入:{},分數:{:?}",
268            choice,
269            choice_reordered,
270            pattern,
271            score_opt,
272        );
273        if let Some(mut score) = score_opt {
274            if first {
275                // NOTE: 正常排序的分數會稍微高一點
276                // 例如 [a/b, b/a] 中要找 `a/b`,則前者以分毫之差勝出
277                score += 1;
278                log::trace!(
279                    "模糊搜尋,候選者:{},正常排序就命中,分數略提升為 {}",
280                    choice,
281                    score
282                );
283            }
284
285            if let Some(ans) = ans_opt {
286                ans_opt = Some(std::cmp::max(score, ans));
287            } else {
288                ans_opt = Some(score);
289            }
290        }
291        first = false;
292    });
293    ans_opt
294}
295
/// Return type of a `foreach_reorder` handler: tells the walk whether to stop early.
///
/// `()` never stops, so every reordering is visited. `bool` stops as soon as the handler
/// returns `true`.
trait StopIndicator: Default {
    /// Whether the enumeration should stop now. By default it never stops.
    fn should_stop(&self) -> bool {
        false
    }
}
impl StopIndicator for () {}
impl StopIndicator for bool {
    fn should_stop(&self) -> bool {
        matches!(*self, true)
    }
}
307fn foreach_reorder<S: StopIndicator, F: FnMut(&str) -> S>(
308    choice: &str,
309    sep: &str,
310    handler: &mut F,
311) {
312    fn recursive_reorder<'a, S: StopIndicator, F: FnMut(&str) -> S>(
313        choice_arr: &[&'a str],
314        mem: &mut Vec<bool>,
315        reorderd: &mut Vec<&'a str>,
316        sep: &str,
317        handler: &mut F,
318    ) -> S {
319        if reorderd.len() == mem.len() {
320            let new_str = reorderd.join(sep);
321            handler(&new_str)
322        } else {
323            for i in 0..mem.len() {
324                if mem[i] {
325                    continue;
326                }
327                mem[i] = true;
328                reorderd.push(choice_arr[i]);
329                let indicator = recursive_reorder(choice_arr, mem, reorderd, sep, handler);
330                if indicator.should_stop() {
331                    return indicator;
332                }
333                reorderd.pop();
334                mem[i] = false;
335            }
336            Default::default()
337        }
338    }
339
340    let choice_arr: Vec<_> = choice.split(sep).collect();
341    let mut mem = vec![false; choice_arr.len()];
342    let mut reorederd = Vec::<&str>::with_capacity(mem.len());
343    recursive_reorder(&choice_arr, &mut mem, &mut reorederd, sep, handler);
344}
345
#[cfg(test)]
mod test {
    use super::*;
    use crate::my_env_logger;

    /// Unwraps a `Multi` result into its best answer plus the sorted list of that answer and
    /// its close competitors (`still_others` is ignored).
    fn extract_multifuzz<'a>(res: FuzzResult<&'a str>) -> (&'a str, Vec<&'a str>) {
        match res {
            Multi { ans, others, .. } => {
                let mut all: Vec<&'a str> = std::iter::once(ans).chain(others).collect();
                all.sort();
                (ans, all)
            }
            other => unreachable!("{:?}", other),
        }
    }
    /// Unwraps a `High` result, failing the test on any other variant.
    fn extract_high<'a>(res: FuzzResult<&'a str>) -> &'a str {
        match res {
            High(found) => found,
            other => unreachable!("{:?}", other),
        }
    }
    /// Runs `fuzz` over `candidates` with `:` as the segment separator.
    async fn do_fuzz<'a>(name: &'a str, candidates: &'a [&'a str]) -> Option<FuzzResult<&'a str>> {
        fuzz(name, candidates.iter().copied(), ":").await.unwrap()
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzz() {
        let _ = my_env_logger::try_init();
        let candidates = vec!["測試腳本1", "測試腳本2", ".42", "測腳本4試"];

        let found = do_fuzz("測試1", &candidates).await.unwrap();
        assert_eq!(extract_high(found), candidates[0]);

        let found = do_fuzz("42", &candidates).await.unwrap();
        assert_eq!(extract_high(found), candidates[2]);

        assert!(do_fuzz("找不到", &candidates).await.is_none());

        let found = do_fuzz("測試", &candidates).await.unwrap();
        let (best, all) = extract_multifuzz(found);
        assert_eq!(all, vec!["測試腳本1", "測試腳本2"]);
        assert_eq!(best, "測試腳本1"); // a genuine tie, so iteration order decides
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzz_with_len() {
        let _ = my_env_logger::try_init();
        let candidates = vec!["測試腳本1", "測試腳本23456"];
        let found = do_fuzz("測試", &candidates).await.unwrap();
        let (best, all) = extract_multifuzz(found);
        assert_eq!(best, "測試腳本1");
        assert_eq!(all, candidates);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_reorder_fuzz() {
        let _ = my_env_logger::try_init();
        let candidates = vec!["a:c", "b:a", "a:b"];

        let (best, all) = extract_multifuzz(do_fuzz("ab", &candidates).await.unwrap());
        assert_eq!(best, "a:b"); // matches in its original order
        assert_eq!(all, vec!["a:b", "b:a"]);

        let (best, all) = extract_multifuzz(do_fuzz("ba", &candidates).await.unwrap());
        assert_eq!(best, "b:a"); // matches in its original order
        assert_eq!(all, vec!["a:b", "b:a"]);

        let found = do_fuzz("ca", &candidates).await.unwrap();
        assert_eq!(extract_high(found), "a:c");

        let (best, all) = extract_multifuzz(do_fuzz("a", &candidates).await.unwrap());
        assert_eq!(best, "a:c"); // a genuine tie, so iteration order decides
        assert_eq!(all, vec!["a:b", "a:c", "b:a"]);

        let candidates = vec!["a:b:c", "c:a"];
        let (best, all) = extract_multifuzz(do_fuzz("a", &candidates).await.unwrap());
        assert_eq!(best, "c:a");
        assert_eq!(all, vec!["a:b:c", "c:a"]);
    }

    #[test]
    fn test_reorder() {
        let mut seen: Vec<String> = Vec::new();
        foreach_reorder("aa::bb::cc", "::", &mut |s| seen.push(s.to_owned()));
        seen.sort();
        let expected = vec![
            "aa::bb::cc",
            "aa::cc::bb",
            "bb::aa::cc",
            "bb::cc::aa",
            "cc::aa::bb",
            "cc::bb::aa",
        ];
        assert_eq!(expected, seen);
    }

    #[test]
    fn test_is_prefix() {
        let sep = "::";
        assert!(is_prefix("aa", "aabb", sep));
        assert!(is_prefix("aa::bb", "bb::cc::aa", sep));
        assert!(is_prefix("c", "bb::cc::aa", sep));
        assert!(is_prefix("aa::bb", "bb::aa1", sep));

        assert!(is_prefix("aa::b", "bb::cc::aa", sep));
        assert!(is_prefix("a::bb", "bb::cc::aa", sep));

        assert!(!is_prefix("abb", "aabb", sep));
        assert!(!is_prefix("aabb", "aa::bb", sep));

        assert!(!is_prefix("aa::bb::cc", "aa::bb", sep));
        assert!(!is_prefix("aa::dd", "bb::cc::aa", sep));
    }
}