1use super::ast::{Expression, Value};
4use super::types::{FunctionSignature, VarType};
5use std::collections::HashMap;
6
/// Registry of the built-in functions, constraint patterns, and problem
/// templates made available by [`StandardLibrary::new`].
///
/// Each table is keyed by the entry's name; entries are retrieved with
/// [`StandardLibrary::get_function`], [`StandardLibrary::get_pattern`], and
/// [`StandardLibrary::get_template`].
#[derive(Debug, Clone)]
pub struct StandardLibrary {
    /// Built-in functions (e.g. `abs`, `sqrt`, `sum`), keyed by function name.
    functions: HashMap<String, BuiltinFunction>,
    /// Named constraint patterns (e.g. `all_different`), keyed by pattern name.
    patterns: HashMap<String, Pattern>,
    /// Parameterized problem templates (e.g. `tsp`, `knapsack`), keyed by template name.
    templates: HashMap<String, Template>,
}
17
/// An entry in the standard library's function table: its name, type
/// signature, description, and how it is implemented.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name the function is registered and looked up under.
    pub name: String,
    /// Parameter types and return type.
    pub signature: FunctionSignature,
    /// Short human-readable description.
    pub description: String,
    /// Whether the function is native or defined by a DSL expression.
    pub implementation: FunctionImpl,
}
26
/// How a [`BuiltinFunction`] is implemented.
#[derive(Debug, Clone)]
pub enum FunctionImpl {
    /// Implemented outside the DSL. This variant carries no code, so the
    /// evaluator presumably dispatches on the function name — TODO confirm.
    Native,
    /// Implemented by a DSL expression.
    DSL {
        /// Expression defining the function's result.
        body: Expression,
    },
}
35
/// A named, reusable constraint pattern (e.g. `all_different`) together with
/// the AST it expands to.
#[derive(Debug, Clone)]
pub struct Pattern {
    /// Name the pattern is registered and looked up under.
    pub name: String,
    /// Human-readable description of the constraint the pattern expresses.
    pub description: String,
    /// Names of the pattern's formal parameters, in order.
    pub parameters: Vec<String>,
    /// AST the pattern expands to.
    /// NOTE(review): how `parameters` are substituted into this AST is not
    /// defined here — confirm with the code that expands patterns.
    pub expansion: super::ast::AST,
}
44
/// A parameterized problem definition stored as DSL source text.
#[derive(Debug, Clone)]
pub struct Template {
    /// Name the template is registered and looked up under.
    pub name: String,
    /// Short human-readable description.
    pub description: String,
    /// Parameters referenced as `{name}` placeholders in `body`.
    pub parameters: Vec<TemplateParam>,
    /// DSL source containing `{param_name}` placeholders.
    /// NOTE(review): presumably filled in by textual substitution — confirm
    /// against the code that instantiates templates.
    pub body: String,
}
53
/// One parameter of a [`Template`].
#[derive(Debug, Clone)]
pub struct TemplateParam {
    /// Placeholder name, written as `{name}` inside the template body.
    pub name: String,
    /// Free-form type tag such as `"integer"`, `"number"`, `"array"`, or `"matrix"`.
    pub param_type: String,
    /// Value to use when none is supplied; `None` presumably marks the
    /// parameter as required — TODO confirm.
    pub default: Option<Value>,
}
60
61impl StandardLibrary {
62 pub fn new() -> Self {
64 let mut stdlib = Self {
65 functions: HashMap::new(),
66 patterns: HashMap::new(),
67 templates: HashMap::new(),
68 };
69
70 stdlib.register_builtin_functions();
71 stdlib.register_common_patterns();
72 stdlib.register_templates();
73
74 stdlib
75 }
76
77 fn register_builtin_functions(&mut self) {
79 self.functions.insert(
81 "abs".to_string(),
82 BuiltinFunction {
83 name: "abs".to_string(),
84 signature: FunctionSignature {
85 param_types: vec![VarType::Continuous],
86 return_type: VarType::Continuous,
87 },
88 description: "Absolute value function".to_string(),
89 implementation: FunctionImpl::Native,
90 },
91 );
92
93 self.functions.insert(
94 "sqrt".to_string(),
95 BuiltinFunction {
96 name: "sqrt".to_string(),
97 signature: FunctionSignature {
98 param_types: vec![VarType::Continuous],
99 return_type: VarType::Continuous,
100 },
101 description: "Square root function".to_string(),
102 implementation: FunctionImpl::Native,
103 },
104 );
105
106 self.functions.insert(
108 "sum".to_string(),
109 BuiltinFunction {
110 name: "sum".to_string(),
111 signature: FunctionSignature {
112 param_types: vec![VarType::Array {
113 element_type: Box::new(VarType::Continuous),
114 dimensions: vec![0],
115 }],
116 return_type: VarType::Continuous,
117 },
118 description: "Sum aggregation function".to_string(),
119 implementation: FunctionImpl::Native,
120 },
121 );
122 }
123
124 fn register_common_patterns(&mut self) {
126 self.patterns.insert(
129 "all_different".to_string(),
130 Pattern {
131 name: "all_different".to_string(),
132 description: "Ensures all variables in a set take different values".to_string(),
133 parameters: vec!["variables".to_string()],
134 expansion: super::ast::AST::Program {
135 declarations: vec![],
136 objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
137 super::ast::Value::Number(0.0),
138 )),
139 constraints: vec![], },
141 },
142 );
143
144 self.patterns.insert(
146 "cardinality".to_string(),
147 Pattern {
148 name: "cardinality".to_string(),
149 description: "Constrains the number of true variables in a set".to_string(),
150 parameters: vec![
151 "variables".to_string(),
152 "min_count".to_string(),
153 "max_count".to_string(),
154 ],
155 expansion: super::ast::AST::Program {
156 declarations: vec![],
157 objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
158 super::ast::Value::Number(0.0),
159 )),
160 constraints: vec![],
161 },
162 },
163 );
164
165 self.patterns.insert(
167 "at_most_one".to_string(),
168 Pattern {
169 name: "at_most_one".to_string(),
170 description: "Ensures at most one variable in a set is true".to_string(),
171 parameters: vec!["variables".to_string()],
172 expansion: super::ast::AST::Program {
173 declarations: vec![],
174 objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
175 super::ast::Value::Number(0.0),
176 )),
177 constraints: vec![],
178 },
179 },
180 );
181
182 self.patterns.insert(
184 "exactly_one".to_string(),
185 Pattern {
186 name: "exactly_one".to_string(),
187 description: "Ensures exactly one variable in a set is true".to_string(),
188 parameters: vec!["variables".to_string()],
189 expansion: super::ast::AST::Program {
190 declarations: vec![],
191 objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
192 super::ast::Value::Number(0.0),
193 )),
194 constraints: vec![],
195 },
196 },
197 );
198 }
199
200 fn register_templates(&mut self) {
202 self.templates.insert(
204 "tsp".to_string(),
205 Template {
206 name: "tsp".to_string(),
207 description: "Traveling Salesman Problem template".to_string(),
208 parameters: vec![
209 TemplateParam {
210 name: "n_cities".to_string(),
211 param_type: "integer".to_string(),
212 default: Some(Value::Number(4.0)),
213 },
214 TemplateParam {
215 name: "distance_matrix".to_string(),
216 param_type: "matrix".to_string(),
217 default: None,
218 },
219 ],
220 body: r"
221 param n = {n_cities};
222 param distances = {distance_matrix};
223
224 var x[n, n] binary;
225
226 minimize sum(i in 0..n, j in 0..n: distances[i][j] * x[i,j]);
227
228 subject to
229 forall(i in 0..n): sum(j in 0..n: x[i,j]) == 1;
230 forall(j in 0..n): sum(i in 0..n: x[i,j]) == 1;
231 "
232 .to_string(),
233 },
234 );
235
236 self.templates.insert(
238 "graph_coloring".to_string(),
239 Template {
240 name: "graph_coloring".to_string(),
241 description: "Graph coloring problem template".to_string(),
242 parameters: vec![
243 TemplateParam {
244 name: "n_vertices".to_string(),
245 param_type: "integer".to_string(),
246 default: Some(Value::Number(5.0)),
247 },
248 TemplateParam {
249 name: "n_colors".to_string(),
250 param_type: "integer".to_string(),
251 default: Some(Value::Number(3.0)),
252 },
253 TemplateParam {
254 name: "edges".to_string(),
255 param_type: "array".to_string(),
256 default: None,
257 },
258 ],
259 body: r"
260 param n_vertices = {n_vertices};
261 param n_colors = {n_colors};
262 param edges = {edges};
263
264 var color[n_vertices, n_colors] binary;
265
266 minimize sum(v in 0..n_vertices, c in 0..n_colors: c * color[v,c]);
267
268 subject to
269 forall(v in 0..n_vertices): sum(c in 0..n_colors: color[v,c]) == 1;
270 forall((u,v) in edges, c in 0..n_colors): color[u,c] + color[v,c] <= 1;
271 "
272 .to_string(),
273 },
274 );
275
276 self.templates.insert(
278 "knapsack".to_string(),
279 Template {
280 name: "knapsack".to_string(),
281 description: "0-1 Knapsack problem template".to_string(),
282 parameters: vec![
283 TemplateParam {
284 name: "n_items".to_string(),
285 param_type: "integer".to_string(),
286 default: Some(Value::Number(10.0)),
287 },
288 TemplateParam {
289 name: "weights".to_string(),
290 param_type: "array".to_string(),
291 default: None,
292 },
293 TemplateParam {
294 name: "values".to_string(),
295 param_type: "array".to_string(),
296 default: None,
297 },
298 TemplateParam {
299 name: "capacity".to_string(),
300 param_type: "number".to_string(),
301 default: Some(Value::Number(100.0)),
302 },
303 ],
304 body: r"
305 param n = {n_items};
306 param weights = {weights};
307 param values = {values};
308 param capacity = {capacity};
309
310 var x[n] binary;
311
312 maximize sum(i in 0..n: values[i] * x[i]);
313
314 subject to
315 sum(i in 0..n: weights[i] * x[i]) <= capacity;
316 "
317 .to_string(),
318 },
319 );
320
321 self.templates.insert(
323 "max_cut".to_string(),
324 Template {
325 name: "max_cut".to_string(),
326 description: "Maximum cut problem template".to_string(),
327 parameters: vec![
328 TemplateParam {
329 name: "n_vertices".to_string(),
330 param_type: "integer".to_string(),
331 default: Some(Value::Number(6.0)),
332 },
333 TemplateParam {
334 name: "edges".to_string(),
335 param_type: "array".to_string(),
336 default: None,
337 },
338 TemplateParam {
339 name: "weights".to_string(),
340 param_type: "array".to_string(),
341 default: None,
342 },
343 ],
344 body: r"
345 param n = {n_vertices};
346 param edges = {edges};
347 param weights = {weights};
348
349 var x[n] binary;
350
351 maximize sum((i,j,w) in zip(edges, weights): w * (x[i] + x[j] - 2*x[i]*x[j]));
352 "
353 .to_string(),
354 },
355 );
356 }
357
358 pub fn get_function(&self, name: &str) -> Option<&BuiltinFunction> {
360 self.functions.get(name)
361 }
362
363 pub fn get_pattern(&self, name: &str) -> Option<&Pattern> {
365 self.patterns.get(name)
366 }
367
368 pub fn get_template(&self, name: &str) -> Option<&Template> {
370 self.templates.get(name)
371 }
372}
373
374impl Default for StandardLibrary {
375 fn default() -> Self {
376 Self::new()
377 }
378}