use serde::{Deserialize, Serialize};
/// Category that a piece of user text is classified into.
///
/// Every variant except [`Intent::Unknown`] is targeted by one of the built-in
/// rules created in `IntentClassifier::new`.
///
/// `Eq` and `Hash` are derived so intents can be used as `HashMap`/`HashSet`
/// keys (for example to count classifications per intent).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Intent {
    /// SPARQL query text or talk about queries (built-in keywords include
    /// `select`, `where`, `sparql`, `query`).
    SparqlQuery,
    /// Request to add data (built-in keywords include `insert`, `add`, `triple`).
    RdfInsert,
    /// Request to remove data (built-in keywords include `delete`, `remove`, `drop`).
    RdfDelete,
    /// Question about the schema (built-in keywords include `schema`,
    /// `ontology`, `class`, `property`).
    SchemaQuestion,
    /// Lookup-style request (built-in keywords include `who`, `find`, `list`, `show`).
    FactLookup,
    /// Graph navigation request (built-in keywords include `navigate`,
    /// `browse`, `graph`, `path`).
    Navigation,
    /// Request for help (built-in keywords include `help`, `tutorial`, `guide`).
    Help,
    /// Greeting (built-in keywords include `hello`, `hi`, `hey`).
    Greeting,
    /// Farewell or thanks (built-in keywords include `bye`, `goodbye`, `thanks`).
    Farewell,
    /// Fallback reported by `IntentClassifier::classify` when no rule scores
    /// above zero.
    Unknown,
}
impl std::fmt::Display for Intent {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
/// One scoring rule: the keywords and phrases that count as evidence for an intent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IntentRule {
/// Intent reported when this rule achieves the highest score.
pub intent: Intent,
/// Single words compared for equality against the whitespace-separated
/// tokens of the text; every keyword found adds `weight` to the score.
pub keywords: Vec<String>,
/// Phrases that are lower-cased and searched for as substrings of the text;
/// every pattern found adds `2 * weight` to the score.
pub patterns: Vec<String>,
/// Score contributed by each keyword hit (pattern hits count double).
pub weight: f64,
}
impl IntentRule {
    /// Builds a rule from borrowed string slices, copying every keyword and
    /// pattern into an owned `String`.
    fn new(intent: Intent, keywords: &[&str], patterns: &[&str], weight: f64) -> Self {
        // Shared conversion for both lists: &[&str] -> Vec<String>.
        let owned = |items: &[&str]| -> Vec<String> {
            items.iter().copied().map(String::from).collect()
        };
        let keywords = owned(keywords);
        let patterns = owned(patterns);
        Self {
            intent,
            keywords,
            patterns,
            weight,
        }
    }
}
/// Outcome of classifying one piece of text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationResult {
/// Intent of the best-scoring rule, or `Intent::Unknown` when no rule
/// scored above zero.
pub intent: Intent,
/// Best rule score divided by 10 and capped at `1.0`; `0.0` when nothing matched.
pub confidence: f64,
/// Keywords of the winning rule that were found, as written in the rule.
pub matched_keywords: Vec<String>,
/// Patterns of the winning rule that were found, as written in the rule.
pub matched_patterns: Vec<String>,
}
/// Keyword- and phrase-based intent classifier driven by a list of `IntentRule`s.
#[derive(Debug, Clone)]
pub struct IntentClassifier {
// Rules are scored in this order; a later rule only replaces the current
// best when its score is strictly greater, so earlier rules win ties.
rules: Vec<IntentRule>,
}
impl Default for IntentClassifier {
fn default() -> Self {
Self::new()
}
}
impl IntentClassifier {
    /// Raw score at which the reported confidence saturates at `1.0`.
    const CONFIDENCE_SCALE: f64 = 10.0;

    /// Creates a classifier preloaded with the built-in rules, one for each
    /// [`Intent`] variant other than [`Intent::Unknown`].
    ///
    /// Each rule carries its own weight, so a hit on a `Help` keyword
    /// (weight 2.0) outranks a hit on a `Greeting` keyword (weight 0.8).
    pub fn new() -> Self {
        let rules = vec![
            IntentRule::new(
                Intent::SparqlQuery,
                &[
                    "select",
                    "where",
                    "sparql",
                    "query",
                    "ask",
                    "construct",
                    "describe",
                    "prefix",
                ],
                &["SELECT * WHERE", "PREFIX ", "ASK {", "CONSTRUCT {"],
                1.5,
            ),
            IntentRule::new(
                Intent::RdfInsert,
                &["insert", "add", "create", "triple", "assert"],
                &["insert data", "add triple", "INSERT DATA"],
                1.2,
            ),
            IntentRule::new(
                Intent::RdfDelete,
                &["delete", "remove", "drop", "retract"],
                &["delete data", "remove triple", "DROP GRAPH"],
                1.2,
            ),
            IntentRule::new(
                Intent::SchemaQuestion,
                &[
                    "schema",
                    "ontology",
                    "class",
                    "property",
                    "define",
                    "definition",
                    "concept",
                    "subclass",
                    "domain",
                    "range",
                ],
                &["what is", "what are", "define "],
                1.0,
            ),
            IntentRule::new(
                Intent::FactLookup,
                &[
                    "who", "when", "which", "find", "list", "show", "get", "fetch", "retrieve",
                ],
                &["tell me about", "find all", "list all"],
                1.0,
            ),
            IntentRule::new(
                Intent::Navigation,
                &[
                    "navigate", "browse", "explore", "graph", "node", "edge", "traverse", "path",
                ],
                &["show graph", "browse to", "navigate to"],
                0.9,
            ),
            IntentRule::new(
                Intent::Help,
                &[
                    "help", "assist", "support", "tutorial", "guide", "manual", "howto",
                ],
                &["how to", "how do i", "what can you", "show me how"],
                2.0,
            ),
            IntentRule::new(
                Intent::Greeting,
                &[
                    "hello",
                    "hi",
                    "hey",
                    "greetings",
                    "howdy",
                    "morning",
                    "afternoon",
                    "evening",
                ],
                &["good morning", "good afternoon", "good evening"],
                0.8,
            ),
            IntentRule::new(
                Intent::Farewell,
                &[
                    "bye", "goodbye", "farewell", "ciao", "later", "thanks", "thank",
                ],
                &["see you", "take care", "good bye"],
                1.5,
            ),
        ];
        Self { rules }
    }

    /// Classifies `text` and returns the best-scoring intent.
    ///
    /// Every rule is scored against the text; the rule with the highest score
    /// wins and the first rule wins ties. When no rule scores above zero the
    /// result is [`Intent::Unknown`] with confidence `0.0` and empty match
    /// lists. The confidence is the winning score divided by 10, capped at `1.0`.
    pub fn classify(&self, text: &str) -> ClassificationResult {
        let words = Self::normalize_text(text);
        let raw = Self::fold_keep_punctuation(text);

        let mut best_score = 0.0_f64;
        let mut result = ClassificationResult {
            intent: Intent::Unknown,
            confidence: 0.0,
            matched_keywords: Vec::new(),
            matched_patterns: Vec::new(),
        };
        for rule in &self.rules {
            let (score, kws, pats) = Self::score_against(rule, &words, &raw);
            // Strict comparison: earlier rules keep the win on equal scores.
            if score > best_score {
                best_score = score;
                result.intent = rule.intent.clone();
                result.matched_keywords = kws;
                result.matched_patterns = pats;
            }
        }
        result.confidence = Self::confidence_for(best_score);
        result
    }

    /// Appends `rule` after the existing rules.
    ///
    /// Several rules may target the same intent; each is scored on its own.
    pub fn add_rule(&mut self, rule: IntentRule) {
        self.rules.push(rule);
    }

    /// Scores a single rule against `text`, which is used exactly as given
    /// (callers normally pass the output of [`IntentClassifier::normalize_text`]).
    ///
    /// Returns `(score, matched_keywords, matched_patterns)`: every keyword
    /// found as a whole whitespace-separated word adds `rule.weight`, every
    /// pattern found as a lower-cased substring adds `2 * rule.weight`.
    /// Keywords containing uppercase letters are also tried in lower-cased
    /// form, and empty or blank patterns are ignored.
    pub fn score_rule(&self, rule: &IntentRule, text: &str) -> (f64, Vec<String>, Vec<String>) {
        Self::score_against(rule, text, text)
    }

    /// Normalises text for keyword matching.
    ///
    /// Letters and digits are lower-cased (Unicode-aware, matching the
    /// `to_lowercase` applied to patterns), `_` is kept, every other character
    /// becomes a space, and runs of whitespace collapse to a single space with
    /// no leading or trailing space.
    pub fn normalize_text(text: &str) -> String {
        let mut folded = String::with_capacity(text.len());
        for c in text.chars() {
            if c.is_alphanumeric() || c == '_' {
                // `to_lowercase` may yield more than one char for some letters.
                folded.extend(c.to_lowercase());
            } else {
                folded.push(' ');
            }
        }
        folded.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    /// Classifies every text in `texts`, preserving order.
    pub fn batch_classify(&self, texts: &[&str]) -> Vec<ClassificationResult> {
        texts.iter().map(|t| self.classify(t)).collect()
    }

    /// Returns up to `n` `(intent, confidence)` pairs, highest confidence first.
    ///
    /// There is one candidate per rule, including rules that scored zero, so
    /// an intent covered by several rules can appear more than once.
    pub fn top_n(&self, text: &str, n: usize) -> Vec<(Intent, f64)> {
        let words = Self::normalize_text(text);
        let raw = Self::fold_keep_punctuation(text);
        let mut ranked: Vec<(Intent, f64)> = self
            .rules
            .iter()
            .map(|rule| {
                let (score, _, _) = Self::score_against(rule, &words, &raw);
                (rule.intent.clone(), Self::confidence_for(score))
            })
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(n);
        ranked
    }

    /// Maps a raw rule score to a confidence in `[0.0, 1.0]`.
    ///
    /// Non-positive and NaN scores (possible with negative or NaN rule
    /// weights) map to `0.0` instead of leaking out of range.
    fn confidence_for(score: f64) -> f64 {
        let scaled = score / Self::CONFIDENCE_SCALE;
        // `NaN > 0.0` is false, so NaN takes the `0.0` branch as well.
        if scaled > 0.0 {
            scaled.min(1.0)
        } else {
            0.0
        }
    }

    /// Lower-cases `text` and collapses whitespace but keeps punctuation, so
    /// patterns such as `"ASK {"` or `"SELECT * WHERE"` have something to match.
    fn fold_keep_punctuation(text: &str) -> String {
        text.to_lowercase()
            .split_whitespace()
            .collect::<Vec<_>>()
            .join(" ")
    }

    /// Scores `rule` using `words` for keyword lookup and both `words` and
    /// `raw` for pattern lookup.
    ///
    /// `words` is the punctuation-free text; `raw` is a lower-cased copy that
    /// still contains punctuation. A pattern counts once if it occurs in
    /// either, so every match against `words` is retained and patterns that
    /// contain punctuation become reachable through `raw`.
    fn score_against(
        rule: &IntentRule,
        words: &str,
        raw: &str,
    ) -> (f64, Vec<String>, Vec<String>) {
        let mut score = 0.0_f64;
        let mut matched_kws: Vec<String> = Vec::new();
        let mut matched_pats: Vec<String> = Vec::new();

        for kw in &rule.keywords {
            // Built-in keywords are already lower case; only fold (and
            // allocate) for custom keywords that contain uppercase letters.
            let hit = contains_word(words, kw)
                || (kw.chars().any(char::is_uppercase)
                    && contains_word(words, &kw.to_lowercase()));
            if hit {
                score += rule.weight;
                matched_kws.push(kw.clone());
            }
        }

        for pat in &rule.patterns {
            let needle = pat.to_lowercase();
            // `str::contains("")` is always true; a blank pattern must not
            // award points to every input.
            if needle.trim().is_empty() {
                continue;
            }
            if words.contains(needle.as_str()) || raw.contains(needle.as_str()) {
                score += rule.weight * 2.0;
                matched_pats.push(pat.clone());
            }
        }

        (score, matched_kws, matched_pats)
    }
}
/// Returns `true` when `word` is exactly equal to one of the
/// whitespace-separated tokens of `text` (case-sensitive, no substring match).
fn contains_word(text: &str, word: &str) -> bool {
    for token in text.split_whitespace() {
        if token == word {
            return true;
        }
    }
    false
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh classifier holding only the built-in rules.
    fn clf() -> IntentClassifier {
        IntentClassifier::new()
    }

    /// Classifies `text` with the built-in rules and returns the winning intent.
    fn intent_of(text: &str) -> Intent {
        clf().classify(text).intent
    }

    // ---- normalize_text -------------------------------------------------

    #[test]
    fn test_normalize_lowercase() {
        assert_eq!(
            IntentClassifier::normalize_text("HELLO World"),
            "hello world"
        );
    }

    #[test]
    fn test_normalize_strips_punctuation() {
        let n = IntentClassifier::normalize_text("Hello, world!");
        assert!(!n.contains(','));
        assert!(!n.contains('!'));
        assert_eq!(n, "hello world");
    }

    #[test]
    fn test_normalize_collapses_spaces() {
        // Runs of spaces, tabs/newlines and punctuation must all end up as a
        // single separating space, with nothing leading or trailing.
        assert_eq!(IntentClassifier::normalize_text("foo    bar"), "foo bar");
        assert_eq!(
            IntentClassifier::normalize_text("  foo \t\n bar  "),
            "foo bar"
        );
        assert_eq!(IntentClassifier::normalize_text("foo ,;! bar"), "foo bar");
    }

    #[test]
    fn test_normalize_empty() {
        assert_eq!(IntentClassifier::normalize_text(""), "");
    }

    // ---- classify: one intent per group ---------------------------------

    #[test]
    fn test_classify_sparql_select() {
        let result = clf().classify("SELECT * WHERE { ?s ?p ?o }");
        assert_eq!(result.intent, Intent::SparqlQuery);
        assert!(result.confidence > 0.0);
    }

    #[test]
    fn test_classify_sparql_ask() {
        assert_eq!(
            intent_of("ask where the subject is defined"),
            Intent::SparqlQuery
        );
    }

    #[test]
    fn test_classify_sparql_prefix() {
        // Despite the name, this input exercises the `sparql`/`query` keywords.
        assert_eq!(
            intent_of("Can you run a sparql query for me?"),
            Intent::SparqlQuery
        );
    }

    #[test]
    fn test_classify_rdf_insert() {
        assert_eq!(intent_of("INSERT DATA { <s> <p> <o> }"), Intent::RdfInsert);
    }

    #[test]
    fn test_classify_add_triple() {
        assert_eq!(intent_of("add triple <a> <b> <c>"), Intent::RdfInsert);
    }

    #[test]
    fn test_classify_rdf_delete() {
        assert_eq!(intent_of("DELETE DATA { <s> <p> <o> }"), Intent::RdfDelete);
    }

    #[test]
    fn test_classify_remove_triple() {
        assert_eq!(intent_of("remove triple <a> <b>"), Intent::RdfDelete);
    }

    #[test]
    fn test_classify_schema_question_what_is() {
        assert_eq!(
            intent_of("what is the ontology for Person class?"),
            Intent::SchemaQuestion
        );
    }

    #[test]
    fn test_classify_schema_class() {
        assert_eq!(
            intent_of("define the class hierarchy"),
            Intent::SchemaQuestion
        );
    }

    #[test]
    fn test_classify_fact_lookup_who() {
        assert_eq!(intent_of("who created this resource"), Intent::FactLookup);
    }

    #[test]
    fn test_classify_fact_lookup_find() {
        assert_eq!(
            intent_of("find all authors in the graph"),
            Intent::FactLookup
        );
    }

    #[test]
    fn test_classify_fact_lookup_list() {
        assert_eq!(intent_of("list all available datasets"), Intent::FactLookup);
    }

    #[test]
    fn test_classify_help() {
        assert_eq!(intent_of("help me with SPARQL"), Intent::Help);
    }

    #[test]
    fn test_classify_how_to() {
        assert_eq!(intent_of("how to write a SPARQL query"), Intent::Help);
    }

    #[test]
    fn test_classify_hello() {
        assert_eq!(intent_of("Hello!"), Intent::Greeting);
    }

    #[test]
    fn test_classify_good_morning() {
        assert_eq!(intent_of("Good morning everyone"), Intent::Greeting);
    }

    #[test]
    fn test_classify_hi() {
        assert_eq!(intent_of("Hi there"), Intent::Greeting);
    }

    #[test]
    fn test_classify_bye() {
        assert_eq!(intent_of("Bye for now"), Intent::Farewell);
    }

    #[test]
    fn test_classify_goodbye() {
        assert_eq!(intent_of("Goodbye see you later"), Intent::Farewell);
    }

    #[test]
    fn test_classify_thanks() {
        assert_eq!(intent_of("thanks goodbye"), Intent::Farewell);
    }

    #[test]
    fn test_classify_unknown() {
        let result = clf().classify("xyzzy plugh");
        assert_eq!(result.intent, Intent::Unknown);
        assert!(result.confidence.abs() < 1e-9);
    }

    #[test]
    fn test_classify_confidence_in_range() {
        let result = clf().classify("SELECT * WHERE { ?s ?p ?o }");
        assert!((0.0..=1.0).contains(&result.confidence));
    }

    // ---- custom rules, match lists, batch and ranking --------------------

    #[test]
    fn test_add_custom_rule() {
        let mut c = clf();
        c.add_rule(IntentRule::new(Intent::Navigation, &["teleport"], &[], 5.0));
        let result = c.classify("teleport me to the knowledge graph");
        assert_eq!(result.intent, Intent::Navigation);
    }

    #[test]
    fn test_matched_keywords_non_empty() {
        let result = clf().classify("select all triples where subject is known");
        assert!(!result.matched_keywords.is_empty());
    }

    #[test]
    fn test_batch_classify_correct_length() {
        let texts = ["hello", "select * where {?s ?p ?o}", "bye"];
        let results = clf().batch_classify(&texts);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_batch_classify_intents() {
        let texts = ["hello", "SELECT * WHERE { ?s ?p ?o }"];
        let results = clf().batch_classify(&texts);
        assert_eq!(results[0].intent, Intent::Greeting);
        assert_eq!(results[1].intent, Intent::SparqlQuery);
    }

    #[test]
    fn test_top_n_correct_length() {
        let result = clf().top_n("select where sparql", 3);
        assert!(result.len() <= 3);
    }

    #[test]
    fn test_top_n_sorted_descending() {
        let result = clf().top_n("select all from where construct ask", 5);
        for w in result.windows(2) {
            assert!(w[0].1 >= w[1].1, "not sorted: {:?} vs {:?}", w[0], w[1]);
        }
    }

    #[test]
    fn test_top_n_zero() {
        let result = clf().top_n("hello world", 0);
        assert!(result.is_empty());
    }

    // ---- score_rule -----------------------------------------------------

    #[test]
    fn test_score_rule_keyword_match() {
        let c = clf();
        let rule = IntentRule::new(Intent::Help, &["help"], &[], 1.0);
        let (score, kws, pats) = c.score_rule(&rule, "help me please");
        assert!(score > 0.0);
        assert!(kws.contains(&"help".to_string()));
        assert!(pats.is_empty());
    }

    #[test]
    fn test_score_rule_pattern_match() {
        let c = clf();
        let rule = IntentRule::new(Intent::Help, &[], &["how to"], 1.0);
        let (score, kws, pats) = c.score_rule(&rule, "how to write a query");
        assert!(score > 0.0);
        assert!(kws.is_empty());
        assert!(!pats.is_empty());
    }

    #[test]
    fn test_score_rule_no_match() {
        let c = clf();
        let rule = IntentRule::new(Intent::Greeting, &["hello"], &[], 1.0);
        let (score, _, _) = c.score_rule(&rule, "delete the graph");
        assert!(score.abs() < 1e-9);
    }

    // ---- Display --------------------------------------------------------

    #[test]
    fn test_intent_display() {
        assert_eq!(Intent::SparqlQuery.to_string(), "SparqlQuery");
        assert_eq!(Intent::Unknown.to_string(), "Unknown");
    }
}