use std::collections::HashSet;
use crate::error::{Error, Result};
use crate::triple::Triple;
use crate::{Database, QueryCriteria};
#[derive(Debug, PartialEq, Clone)]
pub enum Token {
Match,
Return,
Where,
LParen, RParen, LBracket, RBracket, Arrow, Dash, Equal, Star, Range, Number(u64),
Identifier(String),
StringLiteral(String),
Colon, Comma, }
#[derive(Debug, Clone)]
pub enum QueryPart {
Variable(String),
Literal(String),
Anonymous,
}
#[derive(Debug, Clone, Copy)]
pub struct PathLength {
pub min: usize,
pub max: usize,
}
impl PathLength {
pub fn single() -> Self {
Self { min: 1, max: 1 }
}
pub fn is_single(&self) -> bool {
self.min == 1 && self.max == 1
}
}
impl Default for PathLength {
fn default() -> Self {
Self::single()
}
}
#[derive(Debug)]
pub struct ParsedQuery {
pub subject: QueryPart,
pub predicate: QueryPart,
pub object: QueryPart,
pub where_clause: Option<(String, String)>, pub return_var: String,
pub path_len: PathLength,
}
pub struct Lexer<'a> {
input: &'a str,
pos: usize,
}
impl<'a> Lexer<'a> {
pub fn new(input: &'a str) -> Self {
Self { input, pos: 0 }
}
pub fn next_token(&mut self) -> Option<Result<Token>> {
self.skip_whitespace();
if self.pos >= self.input.len() {
return None;
}
let c = self.input[self.pos..].chars().next()?;
match c {
'(' => {
self.pos += 1;
Some(Ok(Token::LParen))
}
')' => {
self.pos += 1;
Some(Ok(Token::RParen))
}
'[' => {
self.pos += 1;
Some(Ok(Token::LBracket))
}
']' => {
self.pos += 1;
Some(Ok(Token::RBracket))
}
':' => {
self.pos += 1;
Some(Ok(Token::Colon))
}
',' => {
self.pos += 1;
Some(Ok(Token::Comma))
}
'=' => {
self.pos += 1;
Some(Ok(Token::Equal))
}
'*' => {
self.pos += 1;
Some(Ok(Token::Star))
}
'.' => {
if self.input[self.pos..].starts_with("..") {
self.pos += 2;
Some(Ok(Token::Range))
} else {
self.pos += 1;
Some(Err(Error::Other("Unexpected '.'".to_string())))
}
}
'-' => {
if self.input[self.pos..].starts_with("->") {
self.pos += 2;
Some(Ok(Token::Arrow))
} else {
self.pos += 1;
Some(Ok(Token::Dash))
}
}
'\'' => self.read_string_literal(),
_ if c.is_ascii_digit() => self.read_number(),
_ if c.is_alphabetic() || c == '_' => self.read_identifier(),
_ => {
self.pos += 1;
Some(Err(Error::Other(format!("Unexpected character: {}", c))))
}
}
}
fn skip_whitespace(&mut self) {
while let Some(c) = self.input[self.pos..].chars().next() {
if !c.is_whitespace() {
break;
}
self.pos += c.len_utf8();
}
}
fn read_string_literal(&mut self) -> Option<Result<Token>> {
self.pos += 1; let start = self.pos;
while let Some(c) = self.input[self.pos..].chars().next() {
if c == '\'' {
let s = &self.input[start..self.pos];
self.pos += 1; return Some(Ok(Token::StringLiteral(s.to_string())));
}
self.pos += c.len_utf8();
}
Some(Err(Error::Other("Unterminated string literal".to_string())))
}
fn read_identifier(&mut self) -> Option<Result<Token>> {
let start = self.pos;
while let Some(c) = self.input[self.pos..].chars().next() {
if !c.is_alphanumeric() && c != '_' {
break;
}
self.pos += c.len_utf8();
}
let s = &self.input[start..self.pos];
match s.to_uppercase().as_str() {
"MATCH" => Some(Ok(Token::Match)),
"RETURN" => Some(Ok(Token::Return)),
"WHERE" => Some(Ok(Token::Where)),
_ => Some(Ok(Token::Identifier(s.to_string()))),
}
}
fn read_number(&mut self) -> Option<Result<Token>> {
let start = self.pos;
while let Some(c) = self.input[self.pos..].chars().next() {
if !c.is_ascii_digit() {
break;
}
self.pos += c.len_utf8();
}
let s = &self.input[start..self.pos];
match s.parse::<u64>() {
Ok(v) => Some(Ok(Token::Number(v))),
Err(_) => Some(Err(Error::Other(format!("Invalid number: {}", s)))),
}
}
}
pub struct Parser<'a> {
lexer: Lexer<'a>,
current_token: Option<Token>,
}
impl<'a> Parser<'a> {
pub fn new(input: &'a str) -> Self {
let mut lexer = Lexer::new(input);
let current_token = lexer.next_token().and_then(|r| r.ok());
Self {
lexer,
current_token,
}
}
fn advance(&mut self) -> Result<()> {
match self.lexer.next_token() {
Some(Ok(t)) => self.current_token = Some(t),
Some(Err(e)) => return Err(e),
None => self.current_token = None,
}
Ok(())
}
pub fn parse(&mut self) -> Result<ParsedQuery> {
if self.current_token != Some(Token::Match) {
return Err(Error::Other("Expected MATCH".to_string()));
}
self.advance()?;
let subject = self.parse_node()?;
if self.current_token != Some(Token::Dash) {
return Err(Error::Other("Expected -".to_string()));
}
self.advance()?;
let (predicate, path_len) = self.parse_rel()?;
if self.current_token != Some(Token::Arrow) {
return Err(Error::Other("Expected ->".to_string()));
}
self.advance()?;
let object = self.parse_node()?;
let mut where_clause: Option<(String, String)> = None;
if self.current_token == Some(Token::Where) {
self.advance()?;
let var_name = if let Some(Token::Identifier(var)) = &self.current_token {
var.clone()
} else {
return Err(Error::Other("Expected variable in WHERE".to_string()));
};
self.advance()?;
if self.current_token != Some(Token::Equal) {
return Err(Error::Other("Expected = in WHERE".to_string()));
}
self.advance()?;
let val = if let Some(Token::StringLiteral(val)) = &self.current_token {
val.clone()
} else {
return Err(Error::Other("Expected string literal in WHERE".to_string()));
};
self.advance()?;
where_clause = Some((var_name, val));
}
if self.current_token != Some(Token::Return) {
return Err(Error::Other("Expected RETURN".to_string()));
}
self.advance()?;
let return_var = if let Some(Token::Identifier(var)) = &self.current_token {
var.clone()
} else {
return Err(Error::Other("Expected return variable".to_string()));
};
Ok(ParsedQuery {
subject,
predicate,
object,
where_clause,
return_var,
path_len,
})
}
fn parse_node(&mut self) -> Result<QueryPart> {
if self.current_token != Some(Token::LParen) {
return Err(Error::Other("Expected (".to_string()));
}
self.advance()?;
let part = match &self.current_token {
Some(Token::Colon) => {
self.advance()?;
if let Some(Token::Identifier(s)) = &self.current_token {
let val = s.clone();
self.advance()?;
QueryPart::Literal(val)
} else {
return Err(Error::Other("Expected identifier after :".to_string()));
}
}
Some(Token::Identifier(s)) => {
let val = s.clone();
self.advance()?;
if self.current_token == Some(Token::Colon) {
self.advance()?;
if let Some(Token::Identifier(_)) = &self.current_token {
self.advance()?;
}
}
QueryPart::Variable(val)
}
_ => QueryPart::Anonymous,
};
if self.current_token != Some(Token::RParen) {
return Err(Error::Other("Expected )".to_string()));
}
self.advance()?;
Ok(part)
}
fn parse_rel(&mut self) -> Result<(QueryPart, PathLength)> {
if self.current_token != Some(Token::LBracket) {
return Err(Error::Other("Expected [".to_string()));
}
self.advance()?;
let part = match &self.current_token {
Some(Token::Colon) => {
self.advance()?;
if let Some(Token::Identifier(s)) = &self.current_token {
let val = s.clone();
self.advance()?;
QueryPart::Literal(val)
} else {
return Err(Error::Other("Expected identifier after :".to_string()));
}
}
Some(Token::Identifier(s)) => {
let val = s.clone();
self.advance()?;
if self.current_token == Some(Token::Colon) {
self.advance()?;
if let Some(Token::Identifier(_)) = &self.current_token {
self.advance()?;
}
}
QueryPart::Variable(val)
}
_ => QueryPart::Anonymous,
};
let mut path_len = PathLength::single();
if self.current_token == Some(Token::Star) {
self.advance()?;
path_len = self.parse_path_length()?;
}
if self.current_token != Some(Token::RBracket) {
return Err(Error::Other("Expected ]".to_string()));
}
self.advance()?;
Ok((part, path_len))
}
fn parse_path_length(&mut self) -> Result<PathLength> {
let start = self.parse_number("Expected hop length after *")?;
if self.current_token == Some(Token::Range) {
self.advance()?;
let end = self.parse_number("Expected hop range upper bound")?;
if end < start {
return Err(Error::Other(
"Invalid hop range: upper bound < lower".to_string(),
));
}
Ok(PathLength {
min: start,
max: end,
})
} else {
Ok(PathLength {
min: start,
max: start,
})
}
}
fn parse_number(&mut self, err: &str) -> Result<usize> {
if let Some(Token::Number(n)) = &self.current_token {
let value =
usize::try_from(*n).map_err(|_| Error::Other("Number too large".to_string()))?;
self.advance()?;
Ok(value)
} else {
Err(Error::Other(err.to_string()))
}
}
}
pub fn execute(db: &Database, query: &str) -> Result<Vec<Triple>> {
let mut parser = Parser::new(query);
let parsed = parser.parse()?;
let mut s_criteria = None;
let mut p_criteria = None;
let mut o_criteria = None;
match &parsed.subject {
QueryPart::Literal(val) => {
s_criteria = db.resolve_id(val)?;
if s_criteria.is_none() {
return Ok(vec![]);
} }
QueryPart::Variable(name) => {
if let Some((w_var, w_val)) = &parsed.where_clause
&& w_var == name
{
s_criteria = db.resolve_id(w_val)?;
if s_criteria.is_none() {
return Ok(vec![]);
} }
}
QueryPart::Anonymous => {}
}
match &parsed.predicate {
QueryPart::Literal(val) => {
p_criteria = db.resolve_id(val)?;
if p_criteria.is_none() {
return Ok(vec![]);
}
}
QueryPart::Variable(name) => {
if let Some((w_var, w_val)) = &parsed.where_clause
&& w_var == name
{
p_criteria = db.resolve_id(w_val)?;
if p_criteria.is_none() {
return Ok(vec![]);
}
}
}
QueryPart::Anonymous => {}
}
match &parsed.object {
QueryPart::Literal(val) => {
o_criteria = db.resolve_id(val)?;
if o_criteria.is_none() {
return Ok(vec![]);
}
}
QueryPart::Variable(name) => {
if let Some((w_var, w_val)) = &parsed.where_clause
&& w_var == name
{
o_criteria = db.resolve_id(w_val)?;
if o_criteria.is_none() {
return Ok(vec![]);
}
}
}
QueryPart::Anonymous => {}
}
let criteria = QueryCriteria {
subject_id: s_criteria,
predicate_id: p_criteria,
object_id: o_criteria,
};
if parsed.path_len.is_single() {
Ok(db.query(criteria).collect())
} else {
execute_variable_path(db, criteria, parsed.path_len)
}
}
fn execute_variable_path(
db: &Database,
criteria: QueryCriteria,
path_len: PathLength,
) -> Result<Vec<Triple>> {
if path_len.min == 0 {
return Err(Error::Other("Hop length must be >= 1".to_string()));
}
let predicate_id = criteria.predicate_id.ok_or_else(|| {
Error::Other("Variable length paths require a predicate literal".to_string())
})?;
let mut frontier: Vec<u64> = if let Some(subject) = criteria.subject_id {
vec![subject]
} else {
db.query(QueryCriteria {
subject_id: None,
predicate_id: Some(predicate_id),
object_id: None,
})
.map(|t| t.subject_id)
.collect()
};
frontier.sort_unstable();
frontier.dedup();
let mut results = Vec::new();
let mut depth = 1;
while depth <= path_len.max && !frontier.is_empty() {
let mut next_frontier = Vec::new();
let mut seen_next = HashSet::new();
for subject in &frontier {
let triples = db.query(QueryCriteria {
subject_id: Some(*subject),
predicate_id: Some(predicate_id),
object_id: None,
});
for triple in triples {
let target = triple.object_id;
if depth >= path_len.min
&& criteria.object_id.is_none_or(|expected| expected == target)
{
results.push(triple);
}
if depth < path_len.max && seen_next.insert(target) {
next_frontier.push(target);
}
}
}
frontier = next_frontier;
depth += 1;
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Database, Fact, Options};
use tempfile::tempdir;
#[test]
fn test_lexer() {
let input = "MATCH (a)-[:KNOWS]->(b)";
let mut lexer = Lexer::new(input);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Match);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LParen);
assert_eq!(
lexer.next_token().unwrap().unwrap(),
Token::Identifier("a".to_string())
);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RParen);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Dash);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LBracket);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Colon);
assert_eq!(
lexer.next_token().unwrap().unwrap(),
Token::Identifier("KNOWS".to_string())
);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RBracket);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::Arrow);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::LParen);
assert_eq!(
lexer.next_token().unwrap().unwrap(),
Token::Identifier("b".to_string())
);
assert_eq!(lexer.next_token().unwrap().unwrap(), Token::RParen);
}
#[test]
fn test_parser_simple() {
let input = "MATCH (s)-[p]->(o) RETURN o";
let mut parser = Parser::new(input);
let q = parser.parse().unwrap();
matches!(q.subject, QueryPart::Variable(s) if s == "s");
matches!(q.predicate, QueryPart::Variable(p) if p == "p");
matches!(q.object, QueryPart::Variable(o) if o == "o");
assert!(q.where_clause.is_none());
assert_eq!(q.return_var, "o");
assert!(q.path_len.is_single());
}
#[test]
fn test_parser_mixed() {
let input = "MATCH (:Person)-[:KNOWS]->(o) WHERE o = 'Alice' RETURN o";
let mut parser = Parser::new(input);
let q = parser.parse().unwrap();
matches!(q.subject, QueryPart::Literal(s) if s == "Person");
matches!(q.predicate, QueryPart::Literal(p) if p == "KNOWS");
matches!(q.object, QueryPart::Variable(o) if o == "o");
assert_eq!(
q.where_clause.unwrap(),
("o".to_string(), "Alice".to_string())
);
assert!(q.path_len.is_single());
}
#[test]
fn test_parser_anonymous() {
let input = "MATCH ()-[]->() RETURN x";
let mut parser = Parser::new(input);
let q = parser.parse().unwrap();
matches!(q.subject, QueryPart::Anonymous);
matches!(q.predicate, QueryPart::Anonymous);
matches!(q.object, QueryPart::Anonymous);
assert!(q.path_len.is_single());
}
#[test]
fn test_parser_multi_hop() {
let input = "MATCH (a)-[:KNOWS*1..5]->(b) RETURN b";
let mut parser = Parser::new(input);
let q = parser.parse().unwrap();
assert_eq!(q.path_len.min, 1);
assert_eq!(q.path_len.max, 5);
}
#[test]
fn test_execute_query() {
let tmp = tempdir().unwrap();
let mut db = Database::open(Options::new(tmp.path())).unwrap();
db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
db.add_fact(Fact::new("Bob", "KNOWS", "Charlie")).unwrap();
db.add_fact(Fact::new("Alice", "LIKES", "Coffee")).unwrap();
let _res = db
.execute_query("MATCH (a)-[:KNOWS]->(b) WHERE a = 'Alice' RETURN b")
.unwrap();
let res = db
.execute_query("MATCH (a)-[:LIKES]->(b) RETURN a")
.unwrap();
if res.is_empty() {
assert_eq!(res.len(), 0);
} else {
let a_id = match res[0].get("a").expect("missing a") {
crate::query::executor::Value::Node(id) => *id,
_ => panic!("a should be node id"),
};
assert!(db.resolve_str(a_id).unwrap().is_some());
}
let res = db.execute_query("MATCH (:Bob)-[]->(b) RETURN b").unwrap();
if res.is_empty() {
assert_eq!(res.len(), 0);
} else {
let b_id = match res[0].get("b").expect("missing b") {
crate::query::executor::Value::Node(id) => *id,
_ => panic!("b should be node id"),
};
assert!(db.resolve_str(b_id).unwrap().is_some());
}
let res = db.execute_query("MATCH ()-[]->() RETURN x").unwrap();
assert_eq!(res.len(), 3);
let res = db
.execute_query("MATCH (a)-[]->() WHERE a = 'Nobody' RETURN a")
.unwrap();
assert_eq!(res.len(), 0);
}
#[test]
fn test_execute_optional_match() {
let tmp = tempdir().unwrap();
let mut db = Database::open(Options::new(tmp.path())).unwrap();
db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
db.add_fact(Fact::new("Charlie", "LIKES", "IceCream"))
.unwrap();
let res = db
.execute_query(
"MATCH (a)-[:KNOWS]->(b) OPTIONAL MATCH (b)-[:LIKES]->(c) RETURN a, b, c",
)
.unwrap();
assert_eq!(res.len(), 1);
assert!(matches!(
res[0].get("c"),
Some(crate::query::executor::Value::Null)
));
}
#[test]
fn test_execute_multi_hop() {
let tmp = tempdir().unwrap();
let mut db = Database::open(Options::new(tmp.path())).unwrap();
db.add_fact(Fact::new("Alice", "KNOWS", "Bob")).unwrap();
db.add_fact(Fact::new("Bob", "KNOWS", "Charlie")).unwrap();
db.add_fact(Fact::new("Charlie", "KNOWS", "Dylan")).unwrap();
db.add_fact(Fact::new("Alice", "type", "Person")).unwrap();
let res = db
.execute_query("MATCH (start:Person)-[:KNOWS*1..2]->(dst) RETURN dst")
.unwrap();
let mut ends: Vec<String> = res
.iter()
.filter_map(|row| match row.get("dst") {
Some(crate::query::executor::Value::Node(id)) => db.resolve_str(*id).unwrap(),
_ => None,
})
.collect();
ends.sort();
ends.dedup();
assert_eq!(ends, vec!["Bob".to_string(), "Charlie".to_string()]);
let err = db
.execute_query("MATCH (a)-[p*1..2]->(b) RETURN b")
.unwrap_err();
assert!(matches!(err, crate::Error::NotImplemented(_)));
}
}