drain3 0.1.6

Fast log template extraction via fixed-depth prefix trees (Rust port of logpai/Drain3)
Documentation
use drain3::{train, Config, Error};
use std::collections::HashMap;

/// Renders a template as a single space-separated string.
///
/// Walks every position in `0..t.token_count()`. Positions for which
/// `t.is_param` returns true are emitted as `param_str`; every other position
/// is emitted as the next entry of `t.tokens()`, read through a separate
/// cursor that only advances on non-parameter positions.
///
/// NOTE(review): the separate cursor implies `tokens()` holds only the literal
/// (non-parameter) tokens — confirm against the `drain3::Template` docs.
fn render_template_placeholders(t: &drain3::Template, param_str: &str) -> String {
    // Index of the next literal token to read from `t.tokens()`.
    let mut literal_cursor = 0;
    let pieces: Vec<String> = (0..t.token_count())
        .map(|pos| {
            if t.is_param(pos) {
                return param_str.to_string();
            }
            let literal = t.tokens()[literal_cursor].to_string();
            literal_cursor += 1;
            literal
        })
        .collect();
    pieces.join(" ")
}

/// Trains on six sshd log lines with a similarity threshold of 0.4 and expects
/// exactly two rendered templates, each accounting for three lines, with the
/// per-template counts summing to the number of input lines.
#[test]
fn logpai_sshd_scenario() {
    let samples: Vec<String> = [
        "Dec 10 07:07:38 LabSZ sshd[24206]: input_userauth_request: invalid user test9 [preauth]",
        "Dec 10 07:08:28 LabSZ sshd[24208]: input_userauth_request: invalid user webmaster [preauth]",
        "Dec 10 09:12:32 LabSZ sshd[24490]: Failed password for invalid user ftpuser from 0.0.0.0 port 62891 ssh2",
        "Dec 10 09:12:35 LabSZ sshd[24492]: Failed password for invalid user pi from 0.0.0.0 port 49289 ssh2",
        "Dec 10 09:12:44 LabSZ sshd[24501]: Failed password for invalid user ftpuser from 0.0.0.0 port 60836 ssh2",
        "Dec 10 07:28:03 LabSZ sshd[24245]: input_userauth_request: invalid user pgadmin [preauth]",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    let cfg = Config::builder().similarity_threshold(0.4).build();
    let model = train(&samples, cfg.clone()).unwrap();

    // Rendered template text -> number of training lines it should absorb.
    let expected: HashMap<String, usize> = HashMap::from([
        (
            "Dec 10 <*> LabSZ <*> input_userauth_request: invalid user <*> [preauth]".to_string(),
            3,
        ),
        (
            "Dec 10 <*> LabSZ <*> Failed password for invalid user <*> from 0.0.0.0 port <*> ssh2"
                .to_string(),
            3,
        ),
    ]);

    let mut observed: HashMap<String, usize> = HashMap::new();
    let mut seen_lines = 0;
    for tmpl in model.templates() {
        let rendered = render_template_placeholders(tmpl, cfg.param_string.as_ref());
        let hits = tmpl.count();
        *observed.entry(rendered).or_default() += hits;
        seen_lines += hits;
    }

    assert_eq!(observed, expected, "templates mismatch");
    assert_eq!(seen_lines, samples.len(), "total count mismatch");
}

/// Same six sshd lines as the 0.4-threshold scenario, trained with a
/// similarity threshold of 0.75. The three `input_userauth_request` lines are
/// expected to stay as separate verbatim templates (count 1 each), while the
/// three `Failed password` lines are expected to share one template.
#[test]
fn logpai_sshd_scenario_high_sim() {
    let samples: Vec<String> = [
        "Dec 10 07:07:38 LabSZ sshd[24206]: input_userauth_request: invalid user test9 [preauth]",
        "Dec 10 07:08:28 LabSZ sshd[24208]: input_userauth_request: invalid user webmaster [preauth]",
        "Dec 10 09:12:32 LabSZ sshd[24490]: Failed password for invalid user ftpuser from 0.0.0.0 port 62891 ssh2",
        "Dec 10 09:12:35 LabSZ sshd[24492]: Failed password for invalid user pi from 0.0.0.0 port 49289 ssh2",
        "Dec 10 09:12:44 LabSZ sshd[24501]: Failed password for invalid user ftpuser from 0.0.0.0 port 60836 ssh2",
        "Dec 10 07:28:03 LabSZ sshd[24245]: input_userauth_request: invalid user pgadmin [preauth]",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    let cfg = Config::builder().similarity_threshold(0.75).build();
    let model = train(&samples, cfg.clone()).unwrap();

    // Lines 0, 1 and 5 are expected to render back exactly as they were given.
    let expected: HashMap<String, usize> = HashMap::from([
        (samples[0].clone(), 1),
        (samples[1].clone(), 1),
        (
            "Dec 10 <*> LabSZ <*> Failed password for invalid user <*> from 0.0.0.0 port <*> ssh2"
                .to_string(),
            3,
        ),
        (samples[5].clone(), 1),
    ]);

    let mut observed: HashMap<String, usize> = HashMap::new();
    let mut seen_lines = 0;
    for tmpl in model.templates() {
        let rendered = render_template_placeholders(tmpl, cfg.param_string.as_ref());
        let hits = tmpl.count();
        *observed.entry(rendered).or_default() += hits;
        seen_lines += hits;
    }

    assert_eq!(observed, expected, "templates mismatch");
    assert_eq!(seen_lines, samples.len(), "total count mismatch");
}

/// Trains on three single-token lines with the default config and expects
/// `hello` to be counted twice and `otherword` once. Templates are rendered
/// with the literal placeholder `"<*>"`.
#[test]
fn logpai_short_message() {
    let lines: Vec<String> = vec![
        "hello".to_string(),
        "hello".to_string(),
        "otherword".to_string(),
    ];
    let model = train(&lines, Config::default()).unwrap();

    let mut observed: HashMap<String, usize> = HashMap::new();
    for tmpl in model.templates() {
        let rendered = render_template_placeholders(tmpl, "<*>");
        *observed.entry(rendered).or_default() += tmpl.count();
    }

    let expected: HashMap<String, usize> =
        HashMap::from([("hello".to_string(), 2), ("otherword".to_string(), 1)]);
    assert_eq!(observed, expected, "templates mismatch");
}

/// Trains on four three-token lines, then checks `match_id` against a table
/// of probes. An expected id of `0` in the table means "no match expected";
/// any other value is the template id the probe must resolve to.
#[test]
fn logpai_match_only() {
    let corpus: Vec<String> = ["aa aa aa", "aa aa bb", "aa aa cc", "xx yy zz"]
        .iter()
        .map(|line| line.to_string())
        .collect();
    let m = train(&corpus, Config::default()).unwrap();

    // (probe line, expected template id; 0 stands for "must not match").
    let cases: [(&str, usize); 4] = [
        ("aa aa tt", 1),
        ("xx yy zz", 2),
        ("xx yy rr", 0),
        ("nothing", 0),
    ];

    for (line, want) in cases {
        let id = m.match_id(line);
        match want {
            0 => {
                assert!(
                    id.is_none(),
                    "Match({line:?}): got id={id:?}, want no match"
                );
            }
            _ => {
                assert_eq!(
                    id,
                    Some(want),
                    "Match({line:?}): got id={id:?}, want id={want}"
                );
            }
        }
    }
}

/// Trains twice on the same four lines with the default config and requires
/// both runs to report equal template lists.
#[test]
fn deterministic_templates() {
    let samples: Vec<String> = [
        "svc 1 INFO user 10",
        "svc 2 INFO user 20",
        "svc 3 ERROR user 30",
        "svc 4 ERROR user 40",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    // One independent training pass over the same input.
    let fit = || train(&samples, Config::default()).unwrap();
    let first = fit();
    let second = fit();

    assert_eq!(
        first.templates(),
        second.templates(),
        "templates are not deterministic"
    );
}

/// Training on an empty slice must succeed, yield no templates, and match
/// nothing.
#[test]
fn train_handles_empty_input() {
    let model = train(&[], Config::default()).unwrap();

    let no_templates = model.templates().is_empty();
    assert!(no_templates, "expected no templates");

    let lookup = model.match_id("anything");
    assert!(lookup.is_none(), "expected no match");
}

/// A config with both `similarity_threshold` and `match_threshold` set to 0.0
/// must be accepted by `train`. With it, two three-token lines are expected to
/// collapse into one template, and a probe sharing only the first token is
/// expected to match.
#[test]
fn zero_thresholds_are_valid() {
    let permissive = Config::builder()
        .similarity_threshold(0.0)
        .match_threshold(0.0)
        .build();
    let lines: Vec<String> = vec!["A B C".to_string(), "A B D".to_string()];
    let model = train(&lines, permissive).unwrap();

    let probe_hit = model.match_id("A X Y");
    assert!(
        probe_hit.is_some(),
        "expected match with 0.0 match threshold"
    );

    let template_total = model.templates().len();
    assert_eq!(template_total, 1, "expected 1 template with 0.0 similarity");
}

/// With `max_clusters(2)`, training on five lines that differ in their first
/// token must fail with `Error::MaxClustersReached`. With `max_clusters(0)`
/// the same input must train successfully and produce more than two templates.
#[test]
fn max_clusters() {
    let lines: Vec<String> = [
        "alpha X Y",
        "bravo X Y",
        "charlie X Y",
        "delta X Y",
        "echo X Y",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();

    // Capped run: expected to hit the cluster limit.
    let capped = Config::builder().max_clusters(2).build();
    let outcome = train(&lines, capped);
    assert!(outcome.is_err(), "train with max_clusters=2 should fail");
    let err = outcome.err().unwrap_or_else(|| panic!("expected error"));
    assert!(
        matches!(err, Error::MaxClustersReached { .. }),
        "expected MaxClustersReached"
    );

    // Run with max_clusters(0): expected to train without hitting a limit.
    let uncapped = Config::builder().max_clusters(0).build();
    let full = train(&lines, uncapped).unwrap();
    let produced = full.templates().len();
    assert!(
        produced > 2,
        "expected uncapped training to produce more than 2 templates: {}",
        produced
    );
}

/// `train` must return an error when the config is built with `depth(2)`,
/// which this test treats as an invalid depth.
#[test]
fn train_validation() {
    let shallow = Config::builder().depth(2).build();
    let outcome = train(&["a b c".into()], shallow);
    assert!(outcome.is_err(), "expected error for invalid depth");
}

/// Builds a config in which every builder option shown here is set to its
/// zero/empty/false value and checks that `train` refuses it.
#[test]
fn zero_value_config_is_rejected() {
    let builder = Config::builder()
        .depth(0)
        .similarity_threshold(0.0)
        .match_threshold(0.0)
        .max_children(0)
        .max_tokens(0)
        .max_bytes(0)
        .max_clusters(0)
        .param_string(String::new().into())
        .parametrize_numeric_tokens(false)
        .extra_delimiters(Vec::new())
        .enable_match_prefilter(false);
    let zero_cfg = builder.build();

    let outcome = train(&["a b c".into()], zero_cfg);
    assert!(outcome.is_err(), "expected error for zero-value Config");
}

/// Trains with `"="` configured as an extra delimiter on two lines that
/// differ only in the value after `a=`, then checks that `match_line` on a
/// third such line reports a match on template id 1 with `"7"` as the single
/// extracted argument.
#[test]
fn extra_delimiters() {
    let with_equals = Config::builder().extra_delimiters(vec!["=".into()]).build();
    let corpus: Vec<String> = vec!["k=v a=1".to_string(), "k=v a=2".to_string()];
    let model = train(&corpus, with_equals).unwrap();

    let (id, args, ok) = model.match_line("k=v a=7");
    assert!(ok, "expected match");
    assert_eq!(id, 1, "expected template id 1, got {id}");
    assert_eq!(args, vec!["7"], "unexpected args: {args:?}");
}

/// Checks that `match_into` agrees with `match_line` on template id, match
/// flag and extracted arguments for a matching line, and that for a
/// non-matching line it reports no match and leaves a cleared buffer empty.
#[test]
fn match_into() {
    let samples: Vec<String> = [
        "service 1 level INFO user 10 action 5",
        "service 2 level INFO user 20 action 5",
        "service 3 level INFO user 30 action 5",
    ]
    .iter()
    .map(|line| line.to_string())
    .collect();
    let model = train(&samples, Config::default()).unwrap();

    let probe = "service 99 level INFO user 777 action 5";
    let (line_id, line_args, line_ok) = model.match_line(probe);

    // Caller-owned buffer that `match_into` fills with the extracted arguments.
    let mut scratch: Vec<String> = Vec::with_capacity(8);
    let (into_id, into_ok) = model.match_into(probe, &mut scratch);

    assert_eq!(line_id, into_id, "MatchInto id mismatch");
    assert_eq!(line_ok, into_ok, "MatchInto ok mismatch");
    assert_eq!(line_args, scratch, "MatchInto args mismatch");
    assert!(!scratch.is_empty(), "expected extracted params");

    // Miss path: start from an empty buffer and expect it to stay empty.
    scratch.clear();
    let (_, ok_miss) = model.match_into("short unmatched", &mut scratch);
    assert!(!ok_miss, "expected no match");
    assert!(
        scratch.is_empty(),
        "expected empty args on miss, got {scratch:?}"
    );
}

/// Reads the config and the template list back from a trained model and
/// checks that the configured delimiter is reported unchanged and that two
/// `templates()` reads agree on the first token of the first template.
///
/// NOTE(review): neither assertion attempts a mutation, so this only verifies
/// that the reads are consistent — confirm whether a stronger aliasing check
/// was intended.
#[test]
fn config_and_templates_are_copied() {
    let with_equals = Config::builder().extra_delimiters(vec!["=".into()]).build();
    let corpus: Vec<String> = vec!["k=v a=1".to_string(), "k=v a=2".to_string()];
    let model = train(&corpus, with_equals).unwrap();

    let seen_cfg = model.config();
    assert_eq!(
        seen_cfg.extra_delimiters[0], "=",
        "config getter leaked mutable slice"
    );

    let first_read = model.templates();
    assert_eq!(
        first_read[0].tokens()[0],
        model.templates()[0].tokens()[0],
        "templates getter leaked mutable data"
    );
}

/// Shares one trained model across four threads through an `Arc` and has each
/// thread call `find` 1000 times on each of four probe lines. The test passes
/// when every thread joins without panicking; the `find` results themselves
/// are discarded.
#[test]
fn concurrent_find_is_sync_safe() {
    use std::sync::Arc;
    use std::thread;

    let corpus: Vec<String> = vec![
        "alpha 123".to_string(),
        "beta 456".to_string(),
        "gamma 789".to_string(),
    ];
    let shared = Arc::new(train(&corpus, Config::default()).unwrap());

    // Spawn every worker before joining any of them.
    let mut workers = Vec::with_capacity(4);
    for _ in 0..4 {
        let local = Arc::clone(&shared);
        workers.push(thread::spawn(move || {
            for _ in 0..1000 {
                for probe in ["alpha 999", "beta 888", "gamma 777", "delta 666"] {
                    local.find(probe);
                }
            }
        }));
    }

    for worker in workers {
        worker.join().expect("thread panicked");
    }
}