radiate-gp 1.2.15

Extensions for radiate. Genetic Programming implementations for graphs (neural networks) and trees.
Documentation
use crate::{Op, Tree, TreeNode};

/// Construction of a value of type `T` from a textual expression.
pub trait Expression<T> {
    /// Parses `expr` into a `T`.
    ///
    /// # Errors
    ///
    /// Returns a `String` describing the failure when `expr` cannot be parsed.
    fn parse(expr: &str) -> Result<T, String>;
}

/// Builds a `Tree<Op<f32>>` from an infix arithmetic expression.
impl Expression<Tree<Op<f32>>> for Tree<Op<f32>> {
    /// Parses `expr` with the module-level `parse` function and wraps the resulting
    /// root node in a `Tree`.
    ///
    /// # Errors
    ///
    /// Forwards the error message produced by the module-level `parse` unchanged.
    fn parse(expr: &str) -> Result<Tree<Op<f32>>, String> {
        let root = parse(expr)?;
        Ok(Tree::new(root))
    }
}

fn parse(expr: &str) -> Result<TreeNode<Op<f32>>, String> {
    let tokens = tokenize(expr);
    let mut pos = 0;
    parse_expression(&tokens, &mut pos)
}

/// Lexical token produced by `tokenize` and consumed by the `parse_*` functions.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Numeric literal.
    Number(f32),
    /// Identifier name together with its variable index. `tokenize` first emits
    /// `None` and fills in `Some(index)` in a pass over all tokens before returning.
    Identifier(String, Option<usize>),
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Multiply,
    /// `/`
    Divide,
    /// `^`
    Power,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// End-of-input marker; `tokenize` appends exactly one after the last token.
    EOF,
}

/// Converts an infix arithmetic expression into a flat list of tokens ending in
/// `Token::EOF`.
///
/// Recognised input: whitespace (skipped), numbers made of ASCII digits and `.`,
/// identifiers that start with an ASCII letter and continue with alphanumerics or
/// `_`, the operators `+ - * / ^`, and parentheses.
///
/// Every identifier is tagged with a variable index: the position of its name in
/// the sorted (by `String` ordering) list of *distinct* identifier names found in
/// `expression`. For `"b + a + b"` that yields `a -> 0` and `b -> 1`.
///
/// # Panics
///
/// Panics on a character outside the grammar above, or on a numeric literal that
/// is not a valid `f32` (for example `"1.2.3"` or a lone `"."`).
fn tokenize(expression: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut vars: Vec<String> = Vec::new();
    let mut chars = expression.chars().peekable();

    while let Some(&ch) = chars.peek() {
        match ch {
            c if c.is_whitespace() => {
                chars.next();
            }
            '0'..='9' | '.' => {
                let mut num = String::new();
                while let Some(c) = chars.next_if(|c| c.is_ascii_digit() || *c == '.') {
                    num.push(c);
                }
                let value = num
                    .parse()
                    .unwrap_or_else(|_| panic!("Invalid number: {}", num));
                tokens.push(Token::Number(value));
            }
            'a'..='z' | 'A'..='Z' => {
                let mut ident = String::new();
                while let Some(c) = chars.next_if(|c| c.is_alphanumeric() || *c == '_') {
                    ident.push(c);
                }
                // The index is filled in below, once every name is known.
                tokens.push(Token::Identifier(ident.clone(), None));
                vars.push(ident);
            }
            _ => {
                let token = match ch {
                    '+' => Token::Plus,
                    '-' => Token::Minus,
                    '*' => Token::Multiply,
                    '/' => Token::Divide,
                    '^' => Token::Power,
                    '(' => Token::LParen,
                    ')' => Token::RParen,
                    _ => panic!("Unexpected character: {}", ch),
                };
                chars.next();
                tokens.push(token);
            }
        }
    }

    tokens.push(Token::EOF);

    // `dedup` only removes *consecutive* duplicates, so the names must be sorted
    // first; otherwise a repeated, non-adjacent name (e.g. `x + y + x`) survives
    // and shifts the indices of the names that sort after it.
    vars.sort_unstable();
    vars.dedup();

    for token in &mut tokens {
        if let Token::Identifier(name, index) = token {
            // Every identifier name was pushed into `vars`, so the search succeeds.
            *index = vars.binary_search(name).ok();
        }
    }

    tokens
}

/// Parses the additive level of the grammar: terms joined by binary `+` and `-`.
///
/// Operators at this level group to the left: each operator node receives the
/// tree built so far as its first child and the newly parsed term as its second.
/// `pos` is advanced past every token consumed; parsing stops (without error) at
/// the first token that is neither `+` nor `-`.
///
/// # Errors
///
/// Propagates any error returned by `parse_term`.
fn parse_expression(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
    let mut lhs = parse_term(tokens, pos)?;

    loop {
        let is_addition = match tokens.get(*pos) {
            Some(Token::Plus) => true,
            Some(Token::Minus) => false,
            _ => break,
        };
        *pos += 1;

        let rhs = parse_term(tokens, pos)?;
        let operator = if is_addition { Op::add() } else { Op::sub() };
        lhs = TreeNode::new(operator).attach(lhs).attach(rhs);
    }

    Ok(lhs)
}

/// Parses the multiplicative level of the grammar: powers joined by `*` and `/`.
///
/// Operators at this level group to the left: each operator node receives the
/// tree built so far as its first child and the newly parsed power as its second.
/// `pos` is advanced past every token consumed; parsing stops (without error) at
/// the first token that is neither `*` nor `/`.
///
/// # Errors
///
/// Propagates any error returned by `parse_power`.
fn parse_term(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
    let mut lhs = parse_power(tokens, pos)?;

    loop {
        let is_product = match tokens.get(*pos) {
            Some(Token::Multiply) => true,
            Some(Token::Divide) => false,
            _ => break,
        };
        *pos += 1;

        let rhs = parse_power(tokens, pos)?;
        let operator = if is_product { Op::mul() } else { Op::div() };
        lhs = TreeNode::new(operator).attach(lhs).attach(rhs);
    }

    Ok(lhs)
}

/// Parses a factor optionally followed by `^` and an exponent.
///
/// The exponent is parsed by a recursive call to this function, so chained
/// powers group to the right: `a ^ b ^ c` becomes `a ^ (b ^ c)`.
///
/// # Errors
///
/// Propagates any error returned by `parse_factor` or by the recursive call.
fn parse_power(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
    let base = parse_factor(tokens, pos)?;

    // No `^` follows: the factor alone is the result.
    if !matches!(tokens.get(*pos), Some(Token::Power)) {
        return Ok(base);
    }

    *pos += 1;
    let exponent = parse_power(tokens, pos)?;
    Ok(TreeNode::new(Op::pow()).attach(base).attach(exponent))
}

/// Parses the tightest-binding level of the grammar: a number, a variable, a
/// parenthesised expression, or one of those prefixed by unary `+` / `-`.
///
/// A unary sign applies to the factor that immediately follows it, and the caller
/// (`parse_power`) looks for `^` only afterwards, so `-x^2` is read as `(-x)^2`.
/// Unary `+` is consumed and produces no node of its own.
///
/// # Errors
///
/// Returns `"Expected ')'"` when a parenthesised expression is not closed, and
/// `"Unexpected token: …"` when the current token cannot start a factor.
///
/// # Panics
///
/// Panics if an identifier token carries no variable index.
fn parse_factor(tokens: &[Token], pos: &mut usize) -> Result<TreeNode<Op<f32>>, String> {
    match tokens.get(*pos) {
        Some(Token::Number(value)) => {
            *pos += 1;
            Ok(TreeNode::new(Op::constant(*value)))
        }
        Some(Token::Identifier(_, index)) => {
            *pos += 1;
            Ok(TreeNode::new(Op::var(index.unwrap())))
        }
        Some(Token::LParen) => {
            *pos += 1;
            let inner = parse_expression(tokens, pos)?;
            match tokens.get(*pos) {
                Some(Token::RParen) => {
                    *pos += 1;
                    Ok(inner)
                }
                _ => Err("Expected ')'".to_string()),
            }
        }
        Some(Token::Plus) => {
            *pos += 1;
            parse_factor(tokens, pos)
        }
        Some(Token::Minus) => {
            *pos += 1;
            let operand = parse_factor(tokens, pos)?;
            Ok(TreeNode::new(Op::neg()).attach(operand))
        }
        other => Err(format!("Unexpected token: {:?}", other)),
    }
}

#[cfg(test)]
mod test {
    use crate::{Eval, Tree, ops::expr::Expression};

    /// Constant-only expression: checks precedence of `+`, `*`, `^` and parentheses.
    #[test]
    fn test_tokenize() {
        let tree = Tree::parse("1 + 2 * (3 * 4)^5")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[]), 497665.0);
    }

    /// Same shape as `test_tokenize`, with every constant replaced by a variable.
    #[test]
    fn test_tokenize_with_vars() {
        let tree = Tree::parse("a + b * (c * d)^e")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[1.0, 2.0, 3.0, 4.0, 5.0]), 497665.0);
    }

    /// Mixes constants with a single variable and a binary minus.
    #[test]
    fn test_tokenize_with_vars_and_negation() {
        let tree = Tree::parse("5 - x * (34 * 3)^2")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        assert_eq!(tree.eval(&[3.0]), -31207.0);
    }

    /// Compares the parsed polynomial against a hand-written closure over twenty
    /// sample points stepping by 0.1 from -0.9.
    #[test]
    fn test_tokenize_with_vars_and_negation_and_parens() {
        let expected = |x: f32| 4.0 * x.powf(3.0) - 3.0 * x.powf(2.0) + x;

        let tree = Tree::parse("4 * x^3 - 3 * x^2 + x")
            .unwrap_or_else(|_| panic!("Failed to parse expression"));

        let mut input = -1.0;
        for _ in 0..20 {
            input += 0.1;
            let actual = tree.eval(&[input]);
            assert!((actual - expected(input)).abs() < 0.0001);
        }
    }
}