use crate::internal::{ast::{Family, Term}, errors::ParseError, lexer::Token};
/// Parses a complete model formula from a lexed token stream.
///
/// The pieces are consumed in order: a response (via `parse_response`), a
/// `~` token, the right-hand side (via `parse_rhs`), and then, only when a
/// comma follows, a `family = <family>` clause (via `parse_family`).
///
/// # Arguments
///
/// * `tokens` - Lexed `(token, source text)` pairs.
/// * `pos` - Cursor into `tokens`; it is passed mutably to every sub-parser.
///
/// # Returns
///
/// The tuple `(response, terms, has_intercept, family)`. `family` is `None`
/// when no comma follows the right-hand side.
///
/// # Errors
///
/// Propagates the first `ParseError` produced by any sub-parser or by an
/// unmet token expectation.
pub fn parse_formula<'a>(
    tokens: &'a [(Token, &'a str)],
    pos: &mut usize,
) -> Result<(String, Vec<Term>, bool, Option<Family>), ParseError> {
    // Left-hand side, then the mandatory `~` separator.
    let response = crate::internal::parse_response::parse_response(tokens, pos)?;
    crate::internal::expect::expect(tokens, pos, |tok| matches!(tok, Token::Tilde), "~")?;

    // Right-hand side: model terms plus the intercept flag.
    let (terms, has_intercept) = crate::internal::parse_rhs::parse_rhs(tokens, pos)?;

    // A comma after the right-hand side introduces the `family = ...` clause.
    let comma_follows =
        crate::internal::matches::matches(tokens, pos, |tok| matches!(tok, Token::Comma));
    let family = if comma_follows {
        crate::internal::expect::expect(
            tokens,
            pos,
            |tok| matches!(tok, Token::Family),
            "family",
        )?;
        crate::internal::expect::expect(tokens, pos, |tok| matches!(tok, Token::Equal), "=")?;
        Some(crate::internal::parse_family::parse_family(tokens, pos)?)
    } else {
        None
    };

    Ok((response, terms, has_intercept, family))
}
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_formula`.
    //!
    //! Successful parses are unwrapped with `expect`, so a regression prints
    //! the underlying `ParseError` rather than a bare "assertion failed".

    use super::*;
    use crate::internal::lexer::Token;

    /// `y ~ x`: one term, implicit intercept, no family.
    #[test]
    fn test_parse_formula_simple() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::ColumnName, "x"),
        ];
        let mut pos = 0;
        let (response, terms, has_intercept, family) =
            parse_formula(&tokens, &mut pos).expect("`y ~ x` should parse");

        assert_eq!(response, "y");
        assert_eq!(terms.len(), 1);
        assert!(has_intercept);
        assert!(family.is_none());
    }

    /// `y ~ x + z`: two terms joined by `+`.
    #[test]
    fn test_parse_formula_with_multiple_terms() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::ColumnName, "x"),
            (Token::Plus, "+"),
            (Token::ColumnName, "z"),
        ];
        let mut pos = 0;
        let (response, terms, has_intercept, family) =
            parse_formula(&tokens, &mut pos).expect("`y ~ x + z` should parse");

        assert_eq!(response, "y");
        assert_eq!(terms.len(), 2);
        assert!(has_intercept);
        assert!(family.is_none());
    }

    /// `y ~ x - 1`: the `- 1` suffix turns the intercept off.
    #[test]
    fn test_parse_formula_without_intercept() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::ColumnName, "x"),
            (Token::Minus, "-"),
            (Token::One, "1"),
        ];
        let mut pos = 0;
        let (response, terms, has_intercept, family) =
            parse_formula(&tokens, &mut pos).expect("`y ~ x - 1` should parse");

        assert_eq!(response, "y");
        assert_eq!(terms.len(), 1);
        assert!(!has_intercept);
        assert!(family.is_none());
    }

    /// `y ~ x, family = gaussian`: the trailing clause yields a family.
    #[test]
    fn test_parse_formula_with_family() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::ColumnName, "x"),
            (Token::Comma, ","),
            (Token::Family, "family"),
            (Token::Equal, "="),
            (Token::Gaussian, "gaussian"),
        ];
        let mut pos = 0;
        let (response, terms, has_intercept, family) = parse_formula(&tokens, &mut pos)
            .expect("`y ~ x, family = gaussian` should parse");

        assert_eq!(response, "y");
        assert_eq!(terms.len(), 1);
        assert!(has_intercept);
        assert_eq!(family, Some(Family::Gaussian));
    }

    /// `y x`: a missing `~` is an error.
    #[test]
    fn test_parse_formula_failure_missing_tilde() {
        let tokens = [(Token::ColumnName, "y"), (Token::ColumnName, "x")];
        let mut pos = 0;
        let result = parse_formula(&tokens, &mut pos);

        assert!(result.is_err());
        // The cursor is expected to stop right after the response token.
        assert_eq!(pos, 1);
    }

    /// `y ~ x,`: a comma with nothing after it is an error.
    #[test]
    fn test_parse_formula_failure_missing_family_after_comma() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::ColumnName, "x"),
            (Token::Comma, ","),
        ];
        let mut pos = 0;
        let result = parse_formula(&tokens, &mut pos);

        assert!(result.is_err());
        // All four tokens, including the comma, are expected to be consumed
        // before the failure is reported.
        assert_eq!(pos, 4);
    }

    /// `y ~ poly(x, 2)`: a function call counts as a single term, and the
    /// comma inside the call must not be taken as the start of a family clause.
    #[test]
    fn test_parse_formula_with_function_terms() {
        let tokens = [
            (Token::ColumnName, "y"),
            (Token::Tilde, "~"),
            (Token::Poly, "poly"),
            (Token::FunctionStart, "("),
            (Token::ColumnName, "x"),
            (Token::Comma, ","),
            (Token::Integer, "2"),
            (Token::FunctionEnd, ")"),
        ];
        let mut pos = 0;
        let (response, terms, has_intercept, family) =
            parse_formula(&tokens, &mut pos).expect("`y ~ poly(x, 2)` should parse");

        assert_eq!(response, "y");
        assert_eq!(terms.len(), 1);
        assert!(has_intercept);
        assert!(family.is_none());
    }

    /// `y ~`: an empty right-hand side parses to zero terms with the
    /// intercept still present.
    #[test]
    fn test_parse_formula_empty_rhs() {
        let tokens = [(Token::ColumnName, "y"), (Token::Tilde, "~")];
        let mut pos = 0;
        let (response, terms, has_intercept, family) =
            parse_formula(&tokens, &mut pos).expect("`y ~` should parse");

        assert_eq!(response, "y");
        assert!(terms.is_empty());
        assert!(has_intercept);
        assert!(family.is_none());
    }
}