use super::ast::{Expression, Value};
use super::types::{FunctionSignature, VarType};
use std::collections::HashMap;
/// Name-indexed registry of the DSL's built-in functions, constraint
/// patterns, and problem templates.
///
/// All three tables are private; entries are read back through the
/// `get_function`, `get_pattern`, and `get_template` lookups.
#[derive(Debug, Clone)]
pub struct StandardLibrary {
    /// Built-in functions, keyed by function name.
    functions: HashMap<String, BuiltinFunction>,
    /// Reusable constraint patterns, keyed by pattern name.
    patterns: HashMap<String, Pattern>,
    /// Parameterised problem templates, keyed by template name.
    templates: HashMap<String, Template>,
}
/// Metadata describing one function offered by the standard library.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name of the function.
    pub name: String,
    /// Parameter and return types of the function.
    pub signature: FunctionSignature,
    /// Short human-readable summary of what the function does.
    pub description: String,
    /// How the function is realised (see [`FunctionImpl`]).
    pub implementation: FunctionImpl,
}
/// The way a built-in function is implemented.
#[derive(Debug, Clone)]
pub enum FunctionImpl {
    /// Carries no body of its own.
    /// NOTE(review): presumably evaluated by host-side code — confirm
    /// against the evaluator.
    Native,
    /// Defined by an expression written in the DSL itself.
    DSL {
        /// Expression that forms the function body.
        body: Expression,
    },
}
/// A named, reusable constraint pattern.
#[derive(Debug, Clone)]
pub struct Pattern {
    /// Name of the pattern.
    pub name: String,
    /// Short human-readable summary of what the pattern enforces.
    pub description: String,
    /// Names of the parameters the pattern takes, in declaration order.
    pub parameters: Vec<String>,
    /// AST associated with the pattern.
    /// NOTE(review): presumably what a use of the pattern expands to —
    /// confirm against the code that consumes it.
    pub expansion: super::ast::AST,
}
/// A named problem template whose body is DSL source text.
#[derive(Debug, Clone)]
pub struct Template {
    /// Name of the template.
    pub name: String,
    /// Short human-readable summary of the problem the template models.
    pub description: String,
    /// Parameters the template accepts, in declaration order.
    pub parameters: Vec<TemplateParam>,
    /// Template source text. The bodies registered in this file contain
    /// `{param_name}` markers that match the declared parameter names.
    /// NOTE(review): the substitution step is not part of this file —
    /// confirm how the markers are filled in.
    pub body: String,
}
/// Declaration of a single parameter accepted by a [`Template`].
#[derive(Debug, Clone)]
pub struct TemplateParam {
    /// Name of the parameter.
    pub name: String,
    /// Free-form type tag for the parameter (the templates registered in
    /// this file use `"integer"`, `"number"`, `"array"`, and `"matrix"`).
    pub param_type: String,
    /// Default value, or `None` when the parameter has no default.
    pub default: Option<Value>,
}
impl StandardLibrary {
pub fn new() -> Self {
let mut stdlib = Self {
functions: HashMap::new(),
patterns: HashMap::new(),
templates: HashMap::new(),
};
stdlib.register_builtin_functions();
stdlib.register_common_patterns();
stdlib.register_templates();
stdlib
}
fn register_builtin_functions(&mut self) {
self.functions.insert(
"abs".to_string(),
BuiltinFunction {
name: "abs".to_string(),
signature: FunctionSignature {
param_types: vec![VarType::Continuous],
return_type: VarType::Continuous,
},
description: "Absolute value function".to_string(),
implementation: FunctionImpl::Native,
},
);
self.functions.insert(
"sqrt".to_string(),
BuiltinFunction {
name: "sqrt".to_string(),
signature: FunctionSignature {
param_types: vec![VarType::Continuous],
return_type: VarType::Continuous,
},
description: "Square root function".to_string(),
implementation: FunctionImpl::Native,
},
);
self.functions.insert(
"sum".to_string(),
BuiltinFunction {
name: "sum".to_string(),
signature: FunctionSignature {
param_types: vec![VarType::Array {
element_type: Box::new(VarType::Continuous),
dimensions: vec![0],
}],
return_type: VarType::Continuous,
},
description: "Sum aggregation function".to_string(),
implementation: FunctionImpl::Native,
},
);
}
fn register_common_patterns(&mut self) {
self.patterns.insert(
"all_different".to_string(),
Pattern {
name: "all_different".to_string(),
description: "Ensures all variables in a set take different values".to_string(),
parameters: vec!["variables".to_string()],
expansion: super::ast::AST::Program {
declarations: vec![],
objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
super::ast::Value::Number(0.0),
)),
constraints: vec![], },
},
);
self.patterns.insert(
"cardinality".to_string(),
Pattern {
name: "cardinality".to_string(),
description: "Constrains the number of true variables in a set".to_string(),
parameters: vec![
"variables".to_string(),
"min_count".to_string(),
"max_count".to_string(),
],
expansion: super::ast::AST::Program {
declarations: vec![],
objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
super::ast::Value::Number(0.0),
)),
constraints: vec![],
},
},
);
self.patterns.insert(
"at_most_one".to_string(),
Pattern {
name: "at_most_one".to_string(),
description: "Ensures at most one variable in a set is true".to_string(),
parameters: vec!["variables".to_string()],
expansion: super::ast::AST::Program {
declarations: vec![],
objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
super::ast::Value::Number(0.0),
)),
constraints: vec![],
},
},
);
self.patterns.insert(
"exactly_one".to_string(),
Pattern {
name: "exactly_one".to_string(),
description: "Ensures exactly one variable in a set is true".to_string(),
parameters: vec!["variables".to_string()],
expansion: super::ast::AST::Program {
declarations: vec![],
objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
super::ast::Value::Number(0.0),
)),
constraints: vec![],
},
},
);
}
fn register_templates(&mut self) {
self.templates.insert(
"tsp".to_string(),
Template {
name: "tsp".to_string(),
description: "Traveling Salesman Problem template".to_string(),
parameters: vec![
TemplateParam {
name: "n_cities".to_string(),
param_type: "integer".to_string(),
default: Some(Value::Number(4.0)),
},
TemplateParam {
name: "distance_matrix".to_string(),
param_type: "matrix".to_string(),
default: None,
},
],
body: r"
param n = {n_cities};
param distances = {distance_matrix};
var x[n, n] binary;
minimize sum(i in 0..n, j in 0..n: distances[i][j] * x[i,j]);
subject to
forall(i in 0..n): sum(j in 0..n: x[i,j]) == 1;
forall(j in 0..n): sum(i in 0..n: x[i,j]) == 1;
"
.to_string(),
},
);
self.templates.insert(
"graph_coloring".to_string(),
Template {
name: "graph_coloring".to_string(),
description: "Graph coloring problem template".to_string(),
parameters: vec![
TemplateParam {
name: "n_vertices".to_string(),
param_type: "integer".to_string(),
default: Some(Value::Number(5.0)),
},
TemplateParam {
name: "n_colors".to_string(),
param_type: "integer".to_string(),
default: Some(Value::Number(3.0)),
},
TemplateParam {
name: "edges".to_string(),
param_type: "array".to_string(),
default: None,
},
],
body: r"
param n_vertices = {n_vertices};
param n_colors = {n_colors};
param edges = {edges};
var color[n_vertices, n_colors] binary;
minimize sum(v in 0..n_vertices, c in 0..n_colors: c * color[v,c]);
subject to
forall(v in 0..n_vertices): sum(c in 0..n_colors: color[v,c]) == 1;
forall((u,v) in edges, c in 0..n_colors): color[u,c] + color[v,c] <= 1;
"
.to_string(),
},
);
self.templates.insert(
"knapsack".to_string(),
Template {
name: "knapsack".to_string(),
description: "0-1 Knapsack problem template".to_string(),
parameters: vec![
TemplateParam {
name: "n_items".to_string(),
param_type: "integer".to_string(),
default: Some(Value::Number(10.0)),
},
TemplateParam {
name: "weights".to_string(),
param_type: "array".to_string(),
default: None,
},
TemplateParam {
name: "values".to_string(),
param_type: "array".to_string(),
default: None,
},
TemplateParam {
name: "capacity".to_string(),
param_type: "number".to_string(),
default: Some(Value::Number(100.0)),
},
],
body: r"
param n = {n_items};
param weights = {weights};
param values = {values};
param capacity = {capacity};
var x[n] binary;
maximize sum(i in 0..n: values[i] * x[i]);
subject to
sum(i in 0..n: weights[i] * x[i]) <= capacity;
"
.to_string(),
},
);
self.templates.insert(
"max_cut".to_string(),
Template {
name: "max_cut".to_string(),
description: "Maximum cut problem template".to_string(),
parameters: vec![
TemplateParam {
name: "n_vertices".to_string(),
param_type: "integer".to_string(),
default: Some(Value::Number(6.0)),
},
TemplateParam {
name: "edges".to_string(),
param_type: "array".to_string(),
default: None,
},
TemplateParam {
name: "weights".to_string(),
param_type: "array".to_string(),
default: None,
},
],
body: r"
param n = {n_vertices};
param edges = {edges};
param weights = {weights};
var x[n] binary;
maximize sum((i,j,w) in zip(edges, weights): w * (x[i] + x[j] - 2*x[i]*x[j]));
"
.to_string(),
},
);
}
pub fn get_function(&self, name: &str) -> Option<&BuiltinFunction> {
self.functions.get(name)
}
pub fn get_pattern(&self, name: &str) -> Option<&Pattern> {
self.patterns.get(name)
}
pub fn get_template(&self, name: &str) -> Option<&Template> {
self.templates.get(name)
}
}
impl Default for StandardLibrary {
fn default() -> Self {
Self::new()
}
}