use anyhow::{Result, anyhow};
use std::collections::HashMap;
use theory_core::{
BinaryOp, ConstValue, Constraint, ConstraintIR, Domain, Expr, TheoryTag, UnaryOp, VarId,
Variable, VariableMetadata,
};
/// Parses an LLM response into a [`ConstraintIR`], discarding the variable-name mapping.
///
/// This is a convenience wrapper around [`parse_llm_response_with_mapping`].
///
/// # Errors
///
/// Propagates any error returned by [`parse_llm_response_with_mapping`].
pub fn parse_llm_response(response: &str, verbose: bool) -> Result<ConstraintIR> {
    parse_llm_response_with_mapping(response, verbose).map(|(ir, _mapping)| ir)
}
/// Parses an LLM response into a [`ConstraintIR`] plus a name mapping.
///
/// The response is split into sections; the `Variables` section (if present)
/// is parsed first so that the `Constraints` section (if present) can refer
/// to the declared names. Theory tags are inferred afterwards.
///
/// The returned map associates every declared variable name with the string
/// `var_<id>`, where `<id>` is the first field of the variable's [`VarId`].
///
/// When `verbose` is true the raw response is echoed to stdout.
///
/// # Errors
///
/// Returns an error if a variable or constraint line cannot be parsed.
pub fn parse_llm_response_with_mapping(
    response: &str,
    verbose: bool,
) -> Result<(ConstraintIR, HashMap<String, String>)> {
    if verbose {
        println!("Parsing LLM response:\n{}\n", response);
    }

    let sections = split_into_sections(response);
    let mut ir = ConstraintIR::new();
    let mut var_map: HashMap<String, VarId> = HashMap::new();

    // Variables must be registered before constraints can reference them.
    if let Some(variable_lines) = sections.get("Variables") {
        parse_variables(&mut ir, &mut var_map, variable_lines, verbose)?;
    }
    if let Some(constraint_lines) = sections.get("Constraints") {
        parse_constraints(&mut ir, &var_map, constraint_lines, verbose)?;
    }
    infer_theory_tags(&mut ir);

    let mut name_mapping: HashMap<String, String> = HashMap::with_capacity(var_map.len());
    for (name, var_id) in &var_map {
        name_mapping.insert(name.clone(), format!("var_{}", var_id.0));
    }
    Ok((ir, name_mapping))
}
/// Groups the bullet lines of `response` under their section headers.
///
/// A header is any trimmed line that ends with `:` and does not start with
/// `-`; its name is the line with all trailing colons removed. A bullet is
/// any trimmed line that starts with `-`; it is stored (with leading dashes
/// and surrounding whitespace removed) under the most recent header.
/// Bullets that appear before the first header, and all other lines, are
/// ignored. Seeing the same header again resets that section to empty.
fn split_into_sections(response: &str) -> HashMap<String, Vec<String>> {
    let mut sections: HashMap<String, Vec<String>> = HashMap::new();
    let mut active: Option<String> = None;

    for raw in response.lines() {
        let text = raw.trim();

        if text.starts_with('-') {
            // Bullet line: attach to the active section, if there is one.
            if let Some(title) = &active {
                if let Some(items) = sections.get_mut(title) {
                    let item = text.trim_start_matches('-').trim();
                    items.push(item.to_string());
                }
            }
            continue;
        }

        if text.ends_with(':') {
            let title = text.trim_end_matches(':').to_string();
            sections.insert(title.clone(), Vec::new());
            active = Some(title);
        }
    }

    sections
}
/// Parses variable declaration lines and registers each variable in `ir`.
///
/// Every line is a comma-separated list of `key: value` fields, e.g.
/// `name: x, type: integer, domain: [0, 10]`. The `name` and `type` fields
/// are required; `domain` is optional and is forwarded to [`parse_domain`].
/// Fields without a `:` are ignored.
///
/// For each variable the name is recorded in `var_map` together with the id
/// returned by `ir.add_variable`. When `verbose` is true a line is printed
/// to stdout for every variable added.
///
/// # Errors
///
/// Returns an error if a line lacks `name` or `type`, or if [`parse_domain`]
/// rejects the type.
fn parse_variables(
    ir: &mut ConstraintIR,
    var_map: &mut HashMap<String, VarId>,
    lines: &[String],
    verbose: bool,
) -> Result<()> {
    for line in lines {
        // Commas inside brackets belong to the value (e.g. `[0, 10]`), so the
        // line must not be split on them; otherwise the domain bounds are lost.
        let parts: HashMap<&str, &str> = split_top_level_fields(line)
            .into_iter()
            .filter_map(|field| {
                let (key, value) = field.split_once(':')?;
                Some((key.trim(), value.trim()))
            })
            .collect();
        let name = parts
            .get("name")
            .ok_or_else(|| anyhow!("Variable missing 'name' field"))?
            .to_string();
        let var_type = parts
            .get("type")
            .ok_or_else(|| anyhow!("Variable missing 'type' field"))?;
        let domain = parse_domain(var_type, parts.get("domain"))?;
        let var = Variable {
            name: name.clone(),
            domain,
            metadata: VariableMetadata::default(),
        };
        let var_id = ir.add_variable(var);
        var_map.insert(name.clone(), var_id);
        if verbose {
            println!(" Added variable: {} ({})", name, var_type);
        }
    }
    Ok(())
}

/// Splits `line` on commas that are not nested inside `[...]` or `(...)`.
fn split_top_level_fields(line: &str) -> Vec<&str> {
    let mut fields = Vec::new();
    let mut depth = 0usize;
    let mut start = 0;
    for (idx, ch) in line.char_indices() {
        match ch {
            '[' | '(' => depth += 1,
            // Saturate so a stray closing bracket cannot underflow.
            ']' | ')' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                fields.push(&line[start..idx]);
                // ',' is one byte wide, so idx + 1 is a valid boundary.
                start = idx + 1;
            }
            _ => {}
        }
    }
    fields.push(&line[start..]);
    fields
}
/// Builds a [`Domain`] from a type name and an optional bounds specification.
///
/// `var_type` is matched case-insensitively against `integer`/`int`,
/// `real`/`float`/`double` and `boolean`/`bool`. For numeric types,
/// `domain_spec` is expected to look like `[min, max]`; a bound that does
/// not parse as a number becomes `None`, and a spec that does not contain
/// exactly two comma-separated parts yields an unbounded domain.
///
/// # Errors
///
/// Returns an error for any other type name.
fn parse_domain(var_type: &str, domain_spec: Option<&&str>) -> Result<Domain> {
    /// Returns the two trimmed bound strings, or `None` unless the spec has
    /// exactly two comma-separated parts once surrounding brackets are removed.
    fn bound_pair<'a>(spec: Option<&&'a str>) -> Option<(&'a str, &'a str)> {
        let cleaned = spec?.trim_matches(|c: char| c == '[' || c == ']');
        let mut pieces = cleaned.split(',').map(|piece| piece.trim());
        match (pieces.next(), pieces.next(), pieces.next()) {
            (Some(low), Some(high), None) => Some((low, high)),
            _ => None,
        }
    }

    let kind = var_type.to_lowercase();
    match kind.as_str() {
        "integer" | "int" => {
            let (min, max) = match bound_pair(domain_spec) {
                Some((low, high)) => (low.parse::<i64>().ok(), high.parse::<i64>().ok()),
                None => (None, None),
            };
            Ok(Domain::Integer { min, max })
        }
        "real" | "float" | "double" => {
            let (min, max) = match bound_pair(domain_spec) {
                Some((low, high)) => (low.parse::<f64>().ok(), high.parse::<f64>().ok()),
                None => (None, None),
            };
            Ok(Domain::Real { min, max })
        }
        "boolean" | "bool" => Ok(Domain::Boolean),
        _ => Err(anyhow!("Unknown variable type: {}", var_type)),
    }
}
/// Parses every constraint line and appends the result to `ir`.
///
/// Lines are processed in order; parsing stops at the first failure. When
/// `verbose` is true a line is printed to stdout for each constraint added.
///
/// # Errors
///
/// Returns the error from [`parse_constraint_expression`] for the first
/// line that cannot be parsed.
fn parse_constraints(
    ir: &mut ConstraintIR,
    var_map: &HashMap<String, VarId>,
    lines: &[String],
    verbose: bool,
) -> Result<()> {
    lines.iter().try_for_each(|text| -> Result<()> {
        let parsed = parse_constraint_expression(text, var_map)?;
        ir.add_constraint(parsed);
        if verbose {
            println!(" Added constraint: {}", text);
        }
        Ok(())
    })
}
fn parse_constraint_expression(expr: &str, var_map: &HashMap<String, VarId>) -> Result<Constraint> {
let expr = expr.trim();
if let Some((lhs, rhs)) = expr.split_once('=') {
if !lhs.contains('<') && !rhs.contains('<') && !lhs.contains('>') && !rhs.contains('>') {
let lhs_expr = parse_expr(lhs.trim(), var_map)?;
let rhs_expr = parse_expr(rhs.trim(), var_map)?;
return Ok(Constraint::Equal {
lhs: lhs_expr,
rhs: rhs_expr,
});
}
}
if let Some((lhs, rhs)) = expr.split_once("<=").or_else(|| expr.split_once('≤')) {
let lhs_expr = parse_expr(lhs.trim(), var_map)?;
let rhs_expr = parse_expr(rhs.trim(), var_map)?;
return Ok(Constraint::LessEqual {
lhs: lhs_expr,
rhs: rhs_expr,
});
}
if let Some((lhs, rhs)) = expr.split_once(">=").or_else(|| expr.split_once('≥')) {
let lhs_expr = parse_expr(lhs.trim(), var_map)?;
let rhs_expr = parse_expr(rhs.trim(), var_map)?;
return Ok(Constraint::LessEqual {
lhs: rhs_expr,
rhs: lhs_expr,
});
}
if let Some((lhs, rhs)) = expr.split_once('<') {
let lhs_expr = parse_expr(lhs.trim(), var_map)?;
let rhs_expr = parse_expr(rhs.trim(), var_map)?;
return Ok(Constraint::Less {
lhs: lhs_expr,
rhs: rhs_expr,
});
}
if let Some((lhs, rhs)) = expr.split_once('>') {
let lhs_expr = parse_expr(lhs.trim(), var_map)?;
let rhs_expr = parse_expr(rhs.trim(), var_map)?;
return Ok(Constraint::Less {
lhs: rhs_expr,
rhs: lhs_expr,
});
}
Err(anyhow!("Could not parse constraint: {}", expr))
}
/// Finds the rightmost occurrence of any of `operators` that is not nested
/// inside parentheses.
///
/// The string is scanned from right to left while tracking parenthesis
/// depth. On a match the function returns the trimmed text to the left of
/// the operator, the trimmed text to its right, and the operator itself.
/// Returns `None` when no operator occurs at depth zero.
///
/// Slicing is done on byte offsets obtained from `char_indices`, so input
/// containing multi-byte characters (for example `√`) is split correctly
/// and never panics on a non-character boundary.
fn find_operator_outside_parens<'a>(
    expr: &'a str,
    operators: &[char],
) -> Option<(&'a str, &'a str, char)> {
    let mut depth = 0i32;
    for (byte_idx, ch) in expr.char_indices().rev() {
        // Scanning right-to-left: a ')' opens a nested group, '(' closes it.
        match ch {
            ')' => depth += 1,
            '(' => depth -= 1,
            _ => {}
        }
        if depth == 0 && operators.contains(&ch) {
            let lhs = &expr[..byte_idx];
            // Skip the operator by its encoded width, not by one byte.
            let rhs = &expr[byte_idx + ch.len_utf8()..];
            return Some((lhs.trim(), rhs.trim(), ch));
        }
    }
    None
}
fn parse_expr(expr: &str, var_map: &HashMap<String, VarId>) -> Result<Expr> {
let expr = expr.trim();
if let Ok(val) = expr.parse::<i64>() {
return Ok(Expr::Const(ConstValue::Integer(val)));
}
if let Ok(val) = expr.parse::<f64>() {
return Ok(Expr::Const(ConstValue::Real(val)));
}
let expr_normalized = expr
.replace(" mod ", " % ")
.replace("^", " ^ ")
.replace("sqrt(", "√(")
.replace("sqrt ", "√ ");
let expr = expr_normalized.as_str();
if expr.starts_with('(') && expr.ends_with(')') {
let mut depth = 0;
let mut closes_at_end = false;
for (i, ch) in expr.chars().enumerate() {
if ch == '(' {
depth += 1;
} else if ch == ')' {
depth -= 1;
if depth == 0 && i == expr.len() - 1 {
closes_at_end = true;
} else if depth == 0 {
break;
}
}
}
if closes_at_end {
return parse_expr(&expr[1..expr.len() - 1], var_map);
}
}
if expr.starts_with('√') {
let inner = expr.trim_start_matches('√').trim();
let inner = if inner.starts_with('(') && inner.ends_with(')') {
&inner[1..inner.len() - 1]
} else {
inner
};
return Ok(Expr::Unary {
op: UnaryOp::Sqrt,
operand: Box::new(parse_expr(inner, var_map)?),
});
}
if let Some(inner) = expr.strip_prefix("abs(") {
if let Some(inner) = inner.strip_suffix(')') {
return Ok(Expr::Unary {
op: UnaryOp::Abs,
operand: Box::new(parse_expr(inner, var_map)?),
});
}
}
if let Some((lhs, rhs, op_char)) = find_operator_outside_parens(expr, &['+', '-']) {
if op_char == '-' && lhs.trim().is_empty() {
} else {
let op = if op_char == '+' {
BinaryOp::Add
} else {
BinaryOp::Sub
};
let lhs_expr = parse_expr(lhs, var_map)?;
let rhs_expr = parse_expr(rhs, var_map)?;
return Ok(Expr::Binary {
op,
lhs: Box::new(lhs_expr),
rhs: Box::new(rhs_expr),
});
}
}
if let Some((lhs, rhs, op_char)) = find_operator_outside_parens(expr, &['*', '/', '%']) {
let op = match op_char {
'*' => BinaryOp::Mul,
'/' => BinaryOp::Div,
'%' => BinaryOp::Mod,
_ => unreachable!(),
};
let lhs_expr = parse_expr(lhs, var_map)?;
let rhs_expr = parse_expr(rhs, var_map)?;
return Ok(Expr::Binary {
op,
lhs: Box::new(lhs_expr),
rhs: Box::new(rhs_expr),
});
}
if let Some((lhs, rhs, _op_char)) = find_operator_outside_parens(expr, &['^']) {
let lhs_expr = parse_expr(lhs, var_map)?;
let rhs_expr = parse_expr(rhs, var_map)?;
return Ok(Expr::Binary {
op: BinaryOp::Power,
lhs: Box::new(lhs_expr),
rhs: Box::new(rhs_expr),
});
}
if let Some(&var_id) = var_map.get(expr) {
return Ok(Expr::Var(var_id));
}
Err(anyhow!("Could not parse expression: {}", expr))
}
/// Adds at most one theory tag to `ir`, chosen from its variables and constraints.
///
/// * `NLA` if any constraint contains a nonlinear operation;
/// * otherwise `LRA` if any variable has a real domain;
/// * otherwise `LIA` if any variable has an integer domain;
/// * otherwise no tag is added.
fn infer_theory_tags(ir: &mut ConstraintIR) {
    let has_integer = ir
        .variables
        .values()
        .any(|var| matches!(var.domain, Domain::Integer { .. }));
    let has_real = ir
        .variables
        .values()
        .any(|var| matches!(var.domain, Domain::Real { .. }));
    let has_nonlinear = (&ir.constraints)
        .into_iter()
        .any(|constraint| has_nonlinear_ops(constraint));

    let tag = if has_nonlinear {
        Some(TheoryTag::NLA)
    } else if has_real {
        Some(TheoryTag::LRA)
    } else if has_integer {
        Some(TheoryTag::LIA)
    } else {
        None
    };

    if let Some(tag) = tag {
        ir.add_theory_tag(tag);
    }
}
/// Reports whether `constraint` contains a nonlinear arithmetic operation.
///
/// Relational constraints are nonlinear when either operand is, according
/// to [`expr_has_nonlinear_ops`]. `Not`, `And` and `Or` are checked
/// recursively. Any other constraint kind is treated as linear.
fn has_nonlinear_ops(constraint: &Constraint) -> bool {
    match constraint {
        Constraint::Equal { lhs, rhs }
        | Constraint::LessEqual { lhs, rhs }
        | Constraint::Less { lhs, rhs } => {
            if expr_has_nonlinear_ops(lhs) {
                true
            } else {
                expr_has_nonlinear_ops(rhs)
            }
        }
        Constraint::Not { constraint: inner } => has_nonlinear_ops(inner),
        Constraint::And { constraints: children } | Constraint::Or { constraints: children } => {
            children.iter().any(|child| has_nonlinear_ops(child))
        }
        _ => false,
    }
}
fn expr_has_nonlinear_ops(expr: &Expr) -> bool {
match expr {
Expr::Binary { op, lhs, rhs } => {
let is_mult_of_vars = matches!(op, BinaryOp::Mul)
&& matches!(lhs.as_ref(), Expr::Var(_))
&& matches!(rhs.as_ref(), Expr::Var(_));
is_mult_of_vars || expr_has_nonlinear_ops(lhs) || expr_has_nonlinear_ops(rhs)
}
_ => false,
}
}
/// Returns the IR produced by [`ConstraintIR::new_simple_test`].
///
/// NOTE(review): presumably intended as a stand-in when an LLM response
/// cannot be parsed — confirm against callers.
pub fn create_simple_fallback() -> ConstraintIR {
ConstraintIR::new_simple_test()
}