1use serde::{Deserialize, Serialize};
8use serde_json::Value as JsonValue;
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use thiserror::Error;
12
13mod client;
14pub use client::OatClient;
15
16#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18#[serde(untagged)]
19pub enum PropertyValue {
20 String(String),
21 Number(i64),
22 Bool(bool),
23}
24
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
27pub struct Bound {
28 pub lower: i64,
29 pub upper: i64,
30}
31
32impl Bound {
33 pub fn new(lower: i64, upper: i64) -> Self {
35 Bound { lower, upper }
36 }
37
38 pub fn binary() -> Self {
40 Bound { lower: 0, upper: 1 }
41 }
42}
43
44impl Default for Bound {
45 fn default() -> Self {
46 Self::binary()
47 }
48}
49
50impl From<Bound> for [i64; 2] {
51 fn from(bound: Bound) -> Self {
52 [bound.lower, bound.upper]
53 }
54}
55
56impl From<[i64; 2]> for Bound {
57 fn from(arr: [i64; 2]) -> Self {
58 Bound {
59 lower: arr[0],
60 upper: arr[1],
61 }
62 }
63}
64
65#[derive(Error, Debug)]
67pub enum OatError {
68 #[error("Connection error: {0}")]
69 Connection(String),
70
71 #[error("Execution error (status {status}): {message}")]
72 Execution {
73 status: u16,
74 message: String,
75 response: JsonValue,
76 },
77
78 #[error("Serialization error: {0}")]
79 Serialization(#[from] serde_json::Error),
80
81 #[error("HTTP error: {0}")]
82 Http(#[from] reqwest::Error),
83}
84
85pub type Result<T> = std::result::Result<T, OatError>;
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90#[serde(untagged)]
91pub enum Arg<T> {
92 Ref { #[serde(rename = "$ref")] id: String },
93 Value(T),
94}
95
96impl<T> Arg<T> {
97 pub fn from_call(call: &FunctionCall) -> Self {
98 Arg::Ref { id: call.out.clone() }
99 }
100
101 pub fn from_value(value: T) -> Self {
102 Arg::Value(value)
103 }
104}
105
106impl<T> From<&FunctionCall> for Arg<T> {
108 fn from(call: &FunctionCall) -> Self {
109 Arg::Ref { id: call.out.clone() }
110 }
111}
112
113impl From<&str> for Arg<String> {
114 fn from(s: &str) -> Self {
115 Arg::Value(s.to_string())
116 }
117}
118
119impl From<String> for Arg<String> {
120 fn from(s: String) -> Self {
121 Arg::Value(s)
122 }
123}
124
125impl From<i64> for Arg<i64> {
126 fn from(val: i64) -> Self {
127 Arg::Value(val)
128 }
129}
130
131#[derive(Debug, Clone)]
133pub struct FunctionCall {
134 pub fn_name: String,
136 pub args: JsonValue,
138 pub out: String,
140}
141
142impl FunctionCall {
143 pub fn new(fn_name: impl Into<String>, args: JsonValue) -> Self {
145 let fn_name = fn_name.into();
146 let out = Self::compute_hash(&fn_name, &args);
147
148 Self { fn_name, args, out }
149 }
150
151 fn serialize_canonical(value: &JsonValue) -> String {
153 match value {
154 JsonValue::Object(map) => {
155 let mut pairs: Vec<_> = map.iter().collect();
157 pairs.sort_by_key(|(k, _)| *k);
158
159 let sorted = pairs
160 .iter()
161 .map(|(k, v)| format!(r#""{}": {}"#, k, Self::serialize_canonical(v)))
162 .collect::<Vec<_>>()
163 .join(", ");
164
165 format!("{{{}}}", sorted)
166 }
167 JsonValue::Array(arr) => {
168 let items = arr
169 .iter()
170 .map(|v| Self::serialize_canonical(v))
171 .collect::<Vec<_>>()
172 .join(", ");
173 format!("[{}]", items)
174 }
175 JsonValue::String(s) => {
176 serde_json::to_string(s).unwrap()
178 }
179 JsonValue::Number(n) => n.to_string(),
180 JsonValue::Bool(b) => b.to_string(),
181 JsonValue::Null => "null".to_string(),
182 }
183 }
184
185 fn compute_hash(fn_name: &str, args: &JsonValue) -> String {
187 let mut hasher = Sha256::new();
188
189 let hash_input = serde_json::json!({
191 "fn": fn_name,
192 "args": args
193 });
194
195 let serialized = Self::serialize_canonical(&hash_input);
197
198 hasher.update(serialized.as_bytes());
199 hex::encode(hasher.finalize())
200 }
201
202 pub fn to_json(&self) -> JsonValue {
204 serde_json::json!({
205 "fn": self.fn_name,
206 "args": self.args,
207 "out": self.out
208 })
209 }
210}
211
212impl PartialEq for FunctionCall {
213 fn eq(&self, other: &Self) -> bool {
214 self.out == other.out
215 }
216}
217
218impl Eq for FunctionCall {}
219
220impl std::hash::Hash for FunctionCall {
221 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
222 self.out.hash(state);
223 }
224}
225
226pub fn set_primitive(id: impl Into<String>, bound: Bound) -> FunctionCall {
232 FunctionCall::new(
233 "set_primitive",
234 serde_json::json!({
235 "id": id.into(),
236 "bound": [bound.lower, bound.upper]
237 }),
238 )
239}
240
241pub fn set_primitives(ids: Vec<String>, bound: Bound) -> FunctionCall {
243 FunctionCall::new(
244 "set_primitives",
245 serde_json::json!({
246 "ids": ids,
247 "bound": [bound.lower, bound.upper]
248 }),
249 )
250}
251
252pub fn set_property(
254 id: impl Into<Arg<String>>,
255 property: impl Into<String>,
256 value: impl Serialize,
257) -> FunctionCall {
258 FunctionCall::new(
259 "set_property",
260 serde_json::json!({
261 "id": id.into(),
262 "property": property.into(),
263 "value": value
264 }),
265 )
266}
267
268pub fn set_properties<V: Serialize>(
270 id: impl Into<Arg<String>>,
271 properties: Vec<(impl Into<Arg<String>>, V)>,
272) -> FunctionCall {
273 let props: Vec<_> = properties
274 .into_iter()
275 .map(|(k, v)| (k.into(), v))
276 .collect();
277 FunctionCall::new(
278 "set_properties",
279 serde_json::json!({
280 "id": id.into(),
281 "properties": props
282 }),
283 )
284}
285
286pub fn set_and(references: Vec<Arg<String>>, alias: Option<String>) -> FunctionCall {
288 let mut args = serde_json::json!({ "references": references });
289 if let Some(a) = alias {
290 args["alias"] = serde_json::json!(a);
291 }
292 FunctionCall::new("set_and", args)
293}
294
295pub fn set_or(references: Vec<Arg<String>>, alias: Option<String>) -> FunctionCall {
297 let mut args = serde_json::json!({ "references": references });
298 if let Some(a) = alias {
299 args["alias"] = serde_json::json!(a);
300 }
301 FunctionCall::new("set_or", args)
302}
303
304pub fn set_not(references: Vec<Arg<String>>, alias: Option<String>) -> FunctionCall {
306 let mut args = serde_json::json!({ "references": references });
307 if let Some(a) = alias {
308 args["alias"] = serde_json::json!(a);
309 }
310 FunctionCall::new("set_not", args)
311}
312
313pub fn set_xor(references: Vec<Arg<String>>, alias: Option<String>) -> FunctionCall {
315 let mut args = serde_json::json!({ "references": references });
316 if let Some(a) = alias {
317 args["alias"] = serde_json::json!(a);
318 }
319 FunctionCall::new("set_xor", args)
320}
321
322pub fn set_imply(lhs: Arg<String>, rhs: Arg<String>, alias: Option<String>) -> FunctionCall {
324 let mut args = serde_json::json!({ "lhs": lhs, "rhs": rhs });
325 if let Some(a) = alias {
326 args["alias"] = serde_json::json!(a);
327 }
328 FunctionCall::new("set_imply", args)
329}
330
331pub fn set_equiv(lhs: Arg<String>, rhs: Arg<String>, alias: Option<String>) -> FunctionCall {
333 let mut args = serde_json::json!({ "lhs": lhs, "rhs": rhs });
334 if let Some(a) = alias {
335 args["alias"] = serde_json::json!(a);
336 }
337 FunctionCall::new("set_equiv", args)
338}
339
340pub fn set_atleast(references: Vec<Arg<String>>, value: i64, alias: Option<String>) -> FunctionCall {
342 let mut args = serde_json::json!({ "references": references, "value": value });
343 if let Some(a) = alias {
344 args["alias"] = serde_json::json!(a);
345 }
346 FunctionCall::new("set_atleast", args)
347}
348
349pub fn set_atmost(references: Vec<Arg<String>>, value: i64, alias: Option<String>) -> FunctionCall {
351 let mut args = serde_json::json!({ "references": references, "value": value });
352 if let Some(a) = alias {
353 args["alias"] = serde_json::json!(a);
354 }
355 FunctionCall::new("set_atmost", args)
356}
357
358pub fn set_equal(references: Vec<Arg<String>>, value: i64, alias: Option<String>) -> FunctionCall {
360 let mut args = serde_json::json!({ "references": references, "value": value });
361 if let Some(a) = alias {
362 args["alias"] = serde_json::json!(a);
363 }
364 FunctionCall::new("set_equal", args)
365}
366
367#[derive(Debug, Clone, Serialize)]
368pub struct Coefficient {
369 pub id: Arg<String>,
370 pub coefficient: Arg<i64>,
371}
372
373pub fn set_gelineq(coefficients: Vec<Coefficient>, bias: i64, alias: Option<String>) -> FunctionCall {
375 let mut args = serde_json::json!({ "coefficients": coefficients, "bias": bias });
376 if let Some(a) = alias {
377 args["alias"] = serde_json::json!(a);
378 }
379 FunctionCall::new("set_gelineq", args)
380}
381
382pub fn sub(root: Arg<String>) -> FunctionCall {
384 FunctionCall::new("sub", serde_json::json!({ "root": root }))
385}
386
387pub fn sub_many(roots: Vec<Arg<String>>) -> FunctionCall {
389 FunctionCall::new("sub_many", serde_json::json!({ "roots": roots }))
390}
391
392#[derive(Debug, Clone, Serialize)]
393pub struct Assignment {
394 pub id: Arg<String>,
395 pub bound: [i64; 2],
396}
397
398pub fn propagate(assignments: Vec<Assignment>) -> FunctionCall {
400 FunctionCall::new("propagate", serde_json::json!({ "assignments": assignments }))
401}
402
403pub fn propagate_with_default(
408 assignments: Vec<Assignment>,
409 default_bound: Option<Bound>,
410) -> FunctionCall {
411 let mut args = serde_json::json!({ "assignments": assignments });
412 if let Some(b) = default_bound {
413 args["default_bound"] = serde_json::json!(b);
414 }
415 FunctionCall::new("propagate", args)
416}
417
418pub fn propagate_many(many_assignments: Vec<Vec<Assignment>>) -> FunctionCall {
420 FunctionCall::new(
421 "propagate_many",
422 serde_json::json!({ "many_assignments": many_assignments }),
423 )
424}
425
426pub fn propagate_many_with_default(
430 many_assignments: Vec<Vec<Assignment>>,
431 default_bound: Option<Bound>,
432) -> FunctionCall {
433 let mut args = serde_json::json!({ "many_assignments": many_assignments });
434 if let Some(b) = default_bound {
435 args["default_bound"] = serde_json::json!(b);
436 }
437 FunctionCall::new("propagate_many", args)
438}
439
440pub fn solve(
442 dag: Arg<HashMap<String, JsonValue>>,
443 objective: Vec<Coefficient>,
444 assume: Vec<Assignment>,
445 maximize: bool,
446) -> FunctionCall {
447 FunctionCall::new(
448 "solve",
449 serde_json::json!({
450 "dag": dag,
451 "objective": objective,
452 "assume": assume,
453 "maximize": maximize
454 }),
455 )
456}
457
458pub fn solve_many(
460 dag: Arg<HashMap<String, JsonValue>>,
461 objectives: Vec<Vec<Coefficient>>,
462 assume: Vec<Assignment>,
463 maximize: bool,
464) -> FunctionCall {
465 FunctionCall::new(
466 "solve_many",
467 serde_json::json!({
468 "dag": dag,
469 "objectives": objectives,
470 "assume": assume,
471 "maximize": maximize
472 }),
473 )
474}
475
476pub fn get_node(id: Arg<String>) -> FunctionCall {
478 FunctionCall::new("get_node", serde_json::json!({ "id": id }))
479}
480
481pub fn get_nodes(ids: Vec<Arg<String>>) -> FunctionCall {
483 FunctionCall::new("get_nodes", serde_json::json!({ "ids": ids }))
484}
485
486pub fn get_many_nodes(many_ids: Arg<Vec<Vec<String>>>) -> FunctionCall {
489 FunctionCall::new(
490 "get_many_nodes",
491 serde_json::json!({ "many_ids": many_ids }),
492 )
493}
494
495pub fn get_node_ids(filter: Option<JsonValue>) -> FunctionCall {
497 let args = if let Some(f) = filter {
498 serde_json::json!({ "filter": f })
499 } else {
500 serde_json::json!({})
501 };
502 FunctionCall::new("get_node_ids", args)
503}
504
505pub fn get_ids_from_dag(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
507 FunctionCall::new("get_ids_from_dag", serde_json::json!({ "dag": dag }))
508}
509
510pub fn get_tighten_dag(
512 dag: Arg<HashMap<String, JsonValue>>,
513 assumptions: Vec<Assignment>,
514) -> FunctionCall {
515 FunctionCall::new(
516 "get_tighten_dag",
517 serde_json::json!({ "dag": dag, "assumptions": assumptions }),
518 )
519}
520
521pub fn get_polyhedron_from_dag(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
523 FunctionCall::new("get_polyhedron_from_dag", serde_json::json!({ "dag": dag }))
524}
525
526pub fn get_roots_from_dag(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
528 FunctionCall::new("get_roots_from_dag", serde_json::json!({ "dag": dag }))
529}
530
531pub fn get_primitive_ids_from_dag(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
533 FunctionCall::new("get_primitive_ids_from_dag", serde_json::json!({ "dag": dag }))
534}
535
536pub fn get_composite_ids_from_dag(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
538 FunctionCall::new("get_composite_ids_from_dag", serde_json::json!({ "dag": dag }))
539}
540
541pub fn get_ids_from_assignments(assignments: Vec<Assignment>) -> FunctionCall {
543 FunctionCall::new("get_ids_from_assignments", serde_json::json!({ "assignments": assignments }))
544}
545
546pub fn get_ids_from_many_assignments(
549 many_assignments: Arg<Vec<Vec<Assignment>>>,
550) -> FunctionCall {
551 FunctionCall::new(
552 "get_ids_from_many_assignments",
553 serde_json::json!({ "many_assignments": many_assignments }),
554 )
555}
556
557pub fn filter_dag(dag: Arg<HashMap<String, JsonValue>>, filter: JsonValue) -> FunctionCall {
559 FunctionCall::new("filter_dag", serde_json::json!({ "dag": dag, "filter": filter }))
560}
561
562pub fn filter_many_ids(
567 many_ids: Arg<Vec<Vec<String>>>,
568 filter: JsonValue,
569) -> FunctionCall {
570 FunctionCall::new(
571 "filter_many_ids",
572 serde_json::json!({ "many_ids": many_ids, "filter": filter }),
573 )
574}
575
576pub fn filter_assignments(
578 assignments: Vec<Assignment>,
579 lower_leq: Arg<i64>,
580 upper_geq: Arg<i64>,
581) -> FunctionCall {
582 FunctionCall::new(
583 "filter_assignments",
584 serde_json::json!({
585 "assignments": assignments,
586 "lower_leq": lower_leq,
587 "upper_geq": upper_geq
588 }),
589 )
590}
591
592pub fn filter_many_assignments(
595 many_assignments: Arg<Vec<Vec<Assignment>>>,
596 lower_leq: Arg<i64>,
597 upper_geq: Arg<i64>,
598) -> FunctionCall {
599 FunctionCall::new(
600 "filter_many_assignments",
601 serde_json::json!({
602 "many_assignments": many_assignments,
603 "lower_leq": lower_leq,
604 "upper_geq": upper_geq
605 }),
606 )
607}
608
609pub fn get_property_values(property: impl Into<String>) -> FunctionCall {
611 FunctionCall::new(
612 "get_property_values",
613 serde_json::json!({ "property": property.into() }),
614 )
615}
616
617pub fn get_alias(id: Arg<String>) -> FunctionCall {
619 FunctionCall::new("get_alias", serde_json::json!({ "id": id }))
620}
621
622pub fn get_id_from_alias(alias: impl Into<String>) -> FunctionCall {
624 FunctionCall::new(
625 "get_id_from_alias",
626 serde_json::json!({ "alias": alias.into() }),
627 )
628}
629
630pub fn get_aliases_from_id(id: Arg<String>) -> FunctionCall {
632 FunctionCall::new("get_aliases_from_id", serde_json::json!({ "id": id }))
633}
634
635pub fn get_ids_from_aliases(aliases: Vec<String>) -> FunctionCall {
637 FunctionCall::new("get_ids_from_aliases", serde_json::json!({ "aliases": aliases }))
638}
639
640pub fn get_node_children(id: Arg<String>) -> FunctionCall {
642 FunctionCall::new("get_node_children", serde_json::json!({ "id": id }))
643}
644
645pub fn get_node_parents(id: Arg<String>) -> FunctionCall {
647 FunctionCall::new("get_node_parents", serde_json::json!({ "id": id }))
648}
649
650pub fn validate(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
652 FunctionCall::new("validate", serde_json::json!({ "dag": dag }))
653}
654
655pub fn ranks(dag: Arg<HashMap<String, JsonValue>>) -> FunctionCall {
657 FunctionCall::new("ranks", serde_json::json!({ "dag": dag }))
658}
659
660pub fn delete_node(id: Arg<String>) -> FunctionCall {
662 FunctionCall::new("delete_node", serde_json::json!({ "id": id }))
663}
664
665pub fn delete_sub(roots: Vec<Arg<String>>) -> FunctionCall {
667 FunctionCall::new("delete_sub", serde_json::json!({ "roots": roots }))
668}
669
670#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
672#[serde(rename_all = "snake_case")]
673pub enum MergeObjectiveStrategy {
674 Sum,
676 ReplaceFromRight,
678}
679
680impl Default for MergeObjectiveStrategy {
681 fn default() -> Self {
682 MergeObjectiveStrategy::ReplaceFromRight
683 }
684}
685
686pub fn as_id(id: Arg<String>) -> FunctionCall {
688 FunctionCall::new("as_id", serde_json::json!({ "id": id }))
689}
690
691pub fn count_ids(ids: Vec<Arg<String>>) -> FunctionCall {
693 FunctionCall::new("count_ids", serde_json::json!({ "ids": ids }))
694}
695
696pub fn objective_from_constant(ids: Vec<Arg<String>>, value: Arg<i64>) -> FunctionCall {
698 FunctionCall::new(
699 "objective_from_constant",
700 serde_json::json!({ "ids": ids, "value": value }),
701 )
702}
703
704pub fn objective_from_doubling_weights(start: Arg<i64>, ids: Vec<Arg<String>>) -> FunctionCall {
706 FunctionCall::new(
707 "objective_from_doubling_weights",
708 serde_json::json!({ "start": start, "ids": ids }),
709 )
710}
711
712pub fn merge_objectives(
714 obj1: Arg<Vec<HashMap<String, i64>>>,
715 obj2: Arg<Vec<HashMap<String, i64>>>,
716 strategy: Option<MergeObjectiveStrategy>,
717) -> FunctionCall {
718 let mut args = serde_json::json!({ "obj1": obj1, "obj2": obj2 });
719 if let Some(s) = strategy {
720 args["strategy"] = serde_json::to_value(s).unwrap();
721 }
722 FunctionCall::new("merge_objectives", args)
723}
724
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the canonical encoding used when hashing calls.
    fn canonical(value: &serde_json::Value) -> String {
        FunctionCall::serialize_canonical(value)
    }

    #[test]
    fn test_canonical_serialization() {
        let unordered = serde_json::json!({ "z": 3, "a": 1, "m": 2 });
        assert_eq!(canonical(&unordered), r#"{"a": 1, "m": 2, "z": 3}"#);
    }

    #[test]
    fn test_hash_determinism() {
        // `Bound::binary()` and an explicit `[0, 1]` bound must yield the same id.
        let via_helper = set_primitive("x", Bound::binary());
        let via_new = set_primitive("x", Bound::new(0, 1));
        assert_eq!(via_helper.out, via_new.out);
    }

    #[test]
    fn test_hash_includes_function_name() {
        let shared_args = serde_json::json!({ "id": "test" });
        let first = FunctionCall::new("function_a", shared_args.clone());
        let second = FunctionCall::new("function_b", shared_args);
        assert_ne!(first.out, second.out);
    }

    #[test]
    fn test_nested_object_sorting() {
        let nested = serde_json::json!({
            "outer_z": { "inner_z": 3, "inner_a": 1 },
            "outer_a": 2
        });
        let rendered = canonical(&nested);
        assert!(rendered.contains(r#""inner_a": 1, "inner_z": 3"#));
        assert!(rendered.starts_with(r#"{"outer_a":"#));
    }

    #[test]
    fn test_reference_in_args() {
        let a = set_primitive("a", Bound::binary());
        let b = set_primitive("b", Bound::binary());
        let refs: Vec<Arg<String>> = vec![Arg::from(&a), Arg::from(&b)];

        let and_constraint = set_and(refs, Some("test_and".to_string()));

        assert!(!and_constraint.out.is_empty());
        // A hex-encoded SHA-256 digest is 64 characters long.
        assert_eq!(and_constraint.out.len(), 64);
    }
}