quantrs2_tytan/problem_dsl/
stdlib.rs

1//! Standard library for the problem DSL.
2
3use super::ast::{Expression, Value};
4use super::types::{FunctionSignature, VarType};
5use std::collections::HashMap;
6
/// Standard library for the problem DSL.
///
/// Holds three independent lookup tables, each keyed by the name of the
/// entry it stores: built-in functions, common constraint patterns and
/// problem templates.
#[derive(Debug, Clone)]
pub struct StandardLibrary {
    /// Built-in functions, keyed by function name.
    functions: HashMap<String, BuiltinFunction>,
    /// Common patterns, keyed by pattern name.
    patterns: HashMap<String, Pattern>,
    /// Problem templates, keyed by template name.
    templates: HashMap<String, Template>,
}
17
/// Built-in function: the metadata stored for one entry of the standard
/// library's function table.
#[derive(Debug, Clone)]
pub struct BuiltinFunction {
    /// Name of the function.
    pub name: String,
    /// Parameter types and return type of the function.
    pub signature: FunctionSignature,
    /// Human-readable summary of what the function does.
    pub description: String,
    /// How the function's behavior is provided (see [`FunctionImpl`]).
    pub implementation: FunctionImpl,
}
26
/// Function implementation: where the behavior of a built-in function
/// comes from.
#[derive(Debug, Clone)]
pub enum FunctionImpl {
    /// Native Rust implementation. The variant carries no data, so the
    /// actual code lives outside this value.
    Native,
    /// DSL implementation: `body` is the expression that defines the
    /// function.
    DSL { body: Expression },
}
35
/// Common pattern: a named, parameterized construct stored in the standard
/// library's pattern table.
#[derive(Debug, Clone)]
pub struct Pattern {
    /// Name of the pattern.
    pub name: String,
    /// Human-readable summary of what the pattern enforces.
    pub description: String,
    /// Names of the pattern's parameters.
    pub parameters: Vec<String>,
    /// AST associated with the pattern.
    /// NOTE(review): presumably the program the pattern expands into; the
    /// patterns registered in this file only carry an empty placeholder
    /// program — confirm where the real expansion is produced.
    pub expansion: super::ast::AST,
}
44
/// Problem template: a named, parameterized piece of problem text stored in
/// the standard library's template table.
#[derive(Debug, Clone)]
pub struct Template {
    /// Name of the template.
    pub name: String,
    /// Human-readable summary of the problem the template models.
    pub description: String,
    /// Parameters the template accepts.
    pub parameters: Vec<TemplateParam>,
    /// Template text. The templates registered in this file reference their
    /// parameters as `{parameter_name}` placeholders.
    /// NOTE(review): the substitution step is not visible here — confirm
    /// how and where placeholders are filled in.
    pub body: String,
}
53
/// One parameter of a [`Template`].
#[derive(Debug, Clone)]
pub struct TemplateParam {
    /// Name of the parameter.
    pub name: String,
    /// Kind of value the parameter takes, stored as free-form text. The
    /// templates registered in this file use `"integer"`, `"number"`,
    /// `"array"` and `"matrix"`.
    pub param_type: String,
    /// Default value, or `None` when the parameter has no default.
    /// NOTE(review): presumably a parameter without a default must be
    /// supplied by the caller — confirm against the template expander.
    pub default: Option<Value>,
}
60
61impl StandardLibrary {
62    /// Create a new standard library
63    pub fn new() -> Self {
64        let mut stdlib = Self {
65            functions: HashMap::new(),
66            patterns: HashMap::new(),
67            templates: HashMap::new(),
68        };
69
70        stdlib.register_builtin_functions();
71        stdlib.register_common_patterns();
72        stdlib.register_templates();
73
74        stdlib
75    }
76
77    /// Register built-in functions
78    fn register_builtin_functions(&mut self) {
79        // Mathematical functions
80        self.functions.insert(
81            "abs".to_string(),
82            BuiltinFunction {
83                name: "abs".to_string(),
84                signature: FunctionSignature {
85                    param_types: vec![VarType::Continuous],
86                    return_type: VarType::Continuous,
87                },
88                description: "Absolute value function".to_string(),
89                implementation: FunctionImpl::Native,
90            },
91        );
92
93        self.functions.insert(
94            "sqrt".to_string(),
95            BuiltinFunction {
96                name: "sqrt".to_string(),
97                signature: FunctionSignature {
98                    param_types: vec![VarType::Continuous],
99                    return_type: VarType::Continuous,
100                },
101                description: "Square root function".to_string(),
102                implementation: FunctionImpl::Native,
103            },
104        );
105
106        // Aggregation functions
107        self.functions.insert(
108            "sum".to_string(),
109            BuiltinFunction {
110                name: "sum".to_string(),
111                signature: FunctionSignature {
112                    param_types: vec![VarType::Array {
113                        element_type: Box::new(VarType::Continuous),
114                        dimensions: vec![0],
115                    }],
116                    return_type: VarType::Continuous,
117                },
118                description: "Sum aggregation function".to_string(),
119                implementation: FunctionImpl::Native,
120            },
121        );
122    }
123
124    /// Register common patterns
125    fn register_common_patterns(&mut self) {
126        // All-different constraint pattern
127        // Pattern for ensuring variables take different values
128        self.patterns.insert(
129            "all_different".to_string(),
130            Pattern {
131                name: "all_different".to_string(),
132                description: "Ensures all variables in a set take different values".to_string(),
133                parameters: vec!["variables".to_string()],
134                expansion: super::ast::AST::Program {
135                    declarations: vec![],
136                    objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
137                        super::ast::Value::Number(0.0),
138                    )),
139                    constraints: vec![], // Would be filled with actual constraints during expansion
140                },
141            },
142        );
143
144        // Cardinality constraint pattern
145        self.patterns.insert(
146            "cardinality".to_string(),
147            Pattern {
148                name: "cardinality".to_string(),
149                description: "Constrains the number of true variables in a set".to_string(),
150                parameters: vec![
151                    "variables".to_string(),
152                    "min_count".to_string(),
153                    "max_count".to_string(),
154                ],
155                expansion: super::ast::AST::Program {
156                    declarations: vec![],
157                    objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
158                        super::ast::Value::Number(0.0),
159                    )),
160                    constraints: vec![],
161                },
162            },
163        );
164
165        // At-most-one constraint pattern
166        self.patterns.insert(
167            "at_most_one".to_string(),
168            Pattern {
169                name: "at_most_one".to_string(),
170                description: "Ensures at most one variable in a set is true".to_string(),
171                parameters: vec!["variables".to_string()],
172                expansion: super::ast::AST::Program {
173                    declarations: vec![],
174                    objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
175                        super::ast::Value::Number(0.0),
176                    )),
177                    constraints: vec![],
178                },
179            },
180        );
181
182        // Exactly-one constraint pattern
183        self.patterns.insert(
184            "exactly_one".to_string(),
185            Pattern {
186                name: "exactly_one".to_string(),
187                description: "Ensures exactly one variable in a set is true".to_string(),
188                parameters: vec!["variables".to_string()],
189                expansion: super::ast::AST::Program {
190                    declarations: vec![],
191                    objective: super::ast::Objective::Minimize(super::ast::Expression::Literal(
192                        super::ast::Value::Number(0.0),
193                    )),
194                    constraints: vec![],
195                },
196            },
197        );
198    }
199
200    /// Register problem templates
201    fn register_templates(&mut self) {
202        // Traveling Salesman Problem template
203        self.templates.insert(
204            "tsp".to_string(),
205            Template {
206                name: "tsp".to_string(),
207                description: "Traveling Salesman Problem template".to_string(),
208                parameters: vec![
209                    TemplateParam {
210                        name: "n_cities".to_string(),
211                        param_type: "integer".to_string(),
212                        default: Some(Value::Number(4.0)),
213                    },
214                    TemplateParam {
215                        name: "distance_matrix".to_string(),
216                        param_type: "matrix".to_string(),
217                        default: None,
218                    },
219                ],
220                body: r"
221                    param n = {n_cities};
222                    param distances = {distance_matrix};
223
224                    var x[n, n] binary;
225
226                    minimize sum(i in 0..n, j in 0..n: distances[i][j] * x[i,j]);
227
228                    subject to
229                        forall(i in 0..n): sum(j in 0..n: x[i,j]) == 1;
230                        forall(j in 0..n): sum(i in 0..n: x[i,j]) == 1;
231                "
232                .to_string(),
233            },
234        );
235
236        // Graph Coloring template
237        self.templates.insert(
238            "graph_coloring".to_string(),
239            Template {
240                name: "graph_coloring".to_string(),
241                description: "Graph coloring problem template".to_string(),
242                parameters: vec![
243                    TemplateParam {
244                        name: "n_vertices".to_string(),
245                        param_type: "integer".to_string(),
246                        default: Some(Value::Number(5.0)),
247                    },
248                    TemplateParam {
249                        name: "n_colors".to_string(),
250                        param_type: "integer".to_string(),
251                        default: Some(Value::Number(3.0)),
252                    },
253                    TemplateParam {
254                        name: "edges".to_string(),
255                        param_type: "array".to_string(),
256                        default: None,
257                    },
258                ],
259                body: r"
260                    param n_vertices = {n_vertices};
261                    param n_colors = {n_colors};
262                    param edges = {edges};
263
264                    var color[n_vertices, n_colors] binary;
265
266                    minimize sum(v in 0..n_vertices, c in 0..n_colors: c * color[v,c]);
267
268                    subject to
269                        forall(v in 0..n_vertices): sum(c in 0..n_colors: color[v,c]) == 1;
270                        forall((u,v) in edges, c in 0..n_colors): color[u,c] + color[v,c] <= 1;
271                "
272                .to_string(),
273            },
274        );
275
276        // Knapsack Problem template
277        self.templates.insert(
278            "knapsack".to_string(),
279            Template {
280                name: "knapsack".to_string(),
281                description: "0-1 Knapsack problem template".to_string(),
282                parameters: vec![
283                    TemplateParam {
284                        name: "n_items".to_string(),
285                        param_type: "integer".to_string(),
286                        default: Some(Value::Number(10.0)),
287                    },
288                    TemplateParam {
289                        name: "weights".to_string(),
290                        param_type: "array".to_string(),
291                        default: None,
292                    },
293                    TemplateParam {
294                        name: "values".to_string(),
295                        param_type: "array".to_string(),
296                        default: None,
297                    },
298                    TemplateParam {
299                        name: "capacity".to_string(),
300                        param_type: "number".to_string(),
301                        default: Some(Value::Number(100.0)),
302                    },
303                ],
304                body: r"
305                    param n = {n_items};
306                    param weights = {weights};
307                    param values = {values};
308                    param capacity = {capacity};
309
310                    var x[n] binary;
311
312                    maximize sum(i in 0..n: values[i] * x[i]);
313
314                    subject to
315                        sum(i in 0..n: weights[i] * x[i]) <= capacity;
316                "
317                .to_string(),
318            },
319        );
320
321        // Maximum Cut template
322        self.templates.insert(
323            "max_cut".to_string(),
324            Template {
325                name: "max_cut".to_string(),
326                description: "Maximum cut problem template".to_string(),
327                parameters: vec![
328                    TemplateParam {
329                        name: "n_vertices".to_string(),
330                        param_type: "integer".to_string(),
331                        default: Some(Value::Number(6.0)),
332                    },
333                    TemplateParam {
334                        name: "edges".to_string(),
335                        param_type: "array".to_string(),
336                        default: None,
337                    },
338                    TemplateParam {
339                        name: "weights".to_string(),
340                        param_type: "array".to_string(),
341                        default: None,
342                    },
343                ],
344                body: r"
345                    param n = {n_vertices};
346                    param edges = {edges};
347                    param weights = {weights};
348
349                    var x[n] binary;
350
351                    maximize sum((i,j,w) in zip(edges, weights): w * (x[i] + x[j] - 2*x[i]*x[j]));
352                "
353                .to_string(),
354            },
355        );
356    }
357
358    /// Get function by name
359    pub fn get_function(&self, name: &str) -> Option<&BuiltinFunction> {
360        self.functions.get(name)
361    }
362
363    /// Get pattern by name
364    pub fn get_pattern(&self, name: &str) -> Option<&Pattern> {
365        self.patterns.get(name)
366    }
367
368    /// Get template by name
369    pub fn get_template(&self, name: &str) -> Option<&Template> {
370        self.templates.get(name)
371    }
372}
373
374impl Default for StandardLibrary {
375    fn default() -> Self {
376        Self::new()
377    }
378}