1use std::collections::HashSet;
2
3use tx3_tir::compile::{CompiledTx, Compiler};
4use tx3_tir::encoding::AnyTir;
5use tx3_tir::model::v1beta0 as tir;
6use tx3_tir::reduce::{Apply as _, ArgMap};
7use tx3_tir::Node as _;
8
9use crate::inputs::CanonicalQuery;
10
11pub mod inputs;
12pub mod interop;
13pub mod trp;
14
15pub use tx3_tir::model::assets::CanonicalAssets;
16pub use tx3_tir::model::core::{Type, Utxo, UtxoRef, UtxoSet};
17
18pub use tx3_tir::model::v1beta0::{Expression, StructExpr};
20
21#[derive(Debug, thiserror::Error)]
22pub enum Error {
23 #[error("can't compile non-constant tir")]
24 CantCompileNonConstantTir,
25
26 #[error(transparent)]
27 CompileError(#[from] tx3_tir::compile::Error),
28
29 #[error(transparent)]
30 InteropError(#[from] interop::Error),
31
32 #[error(transparent)]
33 ReduceError(#[from] tx3_tir::reduce::Error),
34
35 #[error("expected {0}, got {1:?}")]
36 ExpectedData(String, tir::Expression),
37
38 #[error("input query too broad")]
39 InputQueryTooBroad,
40
41 #[error("input not resolved: {0}")]
42 InputNotResolved(String, CanonicalQuery, Vec<UtxoRef>),
43
44 #[error("missing argument `{key}` of type {ty:?}")]
45 MissingTxArg {
46 key: String,
47 ty: tx3_tir::model::core::Type,
48 },
49
50 #[error("transient error: {0}")]
51 TransientError(String),
52
53 #[error("store error: {0}")]
54 StoreError(String),
55
56 #[error("TIR encode / decode error: {0}")]
57 TirEncodingError(#[from] tx3_tir::encoding::Error),
58
59 #[error("tx was not accepted: {0}")]
60 TxNotAccepted(String),
61
62 #[error("tx script returned failure")]
63 TxScriptFailure(Vec<String>),
64}
65
/// A filter passed to [`UtxoStore::narrow_refs`] to select UTxO references.
///
/// Every variant borrows raw byte slices for the lifetime `'a`.
/// NOTE(review): the byte encoding of addresses, policies and asset names is
/// decided by the `UtxoStore` implementation — confirm there.
///
/// The type is `Copy` (it only holds shared references), and it is
/// `Debug`/`Eq`/`Hash` so patterns can be logged, compared and deduplicated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UtxoPattern<'a> {
    /// Pattern keyed on address bytes.
    ByAddress(&'a [u8]),
    /// Pattern keyed on an asset policy (bytes).
    ByAssetPolicy(&'a [u8]),
    /// Pattern keyed on an asset, given as `(policy, name)` bytes.
    ByAsset(&'a [u8], &'a [u8]),
}

impl<'a> UtxoPattern<'a> {
    /// Builds a [`UtxoPattern::ByAddress`] pattern.
    pub const fn by_address(address: &'a [u8]) -> Self {
        Self::ByAddress(address)
    }

    /// Builds a [`UtxoPattern::ByAssetPolicy`] pattern.
    pub const fn by_asset_policy(policy: &'a [u8]) -> Self {
        Self::ByAssetPolicy(policy)
    }

    /// Builds a [`UtxoPattern::ByAsset`] pattern from a policy and an asset name.
    pub const fn by_asset(policy: &'a [u8], name: &'a [u8]) -> Self {
        Self::ByAsset(policy, name)
    }
}
85
/// Read access to a UTxO source, consumed by input resolution
/// (`crate::inputs::resolve` receives a `&S` where `S: UtxoStore`).
///
/// NOTE(review): `#[trait_variant::make(Send)]` presumably makes the futures
/// returned by these async methods `Send` — confirm against the
/// `trait_variant` version in use.
#[trait_variant::make(Send)]
pub trait UtxoStore {
    /// Narrows the store down to the UTxO references matching `pattern`.
    ///
    /// NOTE(review): behavior when nothing matches (empty set vs. error) is
    /// up to the implementation — confirm with implementors.
    async fn narrow_refs(&self, pattern: UtxoPattern<'_>) -> Result<HashSet<UtxoRef>, Error>;
    /// Loads the UTxOs identified by `refs`.
    async fn fetch_utxos(&self, refs: HashSet<UtxoRef>) -> Result<UtxoSet, Error>;
}
91
92async fn eval_pass<C, S>(
93 tx: &AnyTir,
94 compiler: &mut C,
95 utxos: &S,
96 last_eval: Option<&CompiledTx>,
97) -> Result<Option<CompiledTx>, Error>
98where
99 C: Compiler<Expression = tir::Expression, CompilerOp = tir::CompilerOp>,
100 S: UtxoStore,
101{
102 let attempt = tx.clone();
103
104 let fees = last_eval.as_ref().map(|e| e.fee).unwrap_or(0);
105
106 let attempt = tx3_tir::reduce::apply_fees(attempt, fees)?;
107
108 let attempt = attempt.apply(compiler)?;
109
110 let attempt = tx3_tir::reduce::reduce(attempt)?;
111
112 let attempt = crate::inputs::resolve(attempt, utxos).await?;
113
114 let attempt = tx3_tir::reduce::reduce(attempt)?;
115
116 if !attempt.is_constant() {
117 return Err(Error::CantCompileNonConstantTir);
118 }
119
120 let eval = compiler.compile(&attempt)?;
121
122 let Some(last_eval) = last_eval else {
123 return Ok(Some(eval));
124 };
125
126 if eval != *last_eval {
127 return Ok(Some(eval));
128 }
129
130 Ok(None)
131}
132
133fn safe_apply_args(tir: AnyTir, args: &ArgMap) -> Result<AnyTir, Error> {
134 let params = tx3_tir::reduce::find_params(&tir);
135
136 for (key, ty) in params.iter() {
138 if !args.contains_key(key) {
139 return Err(Error::MissingTxArg {
140 key: key.to_string(),
141 ty: ty.clone(),
142 });
143 };
144 }
145
146 let tir = tx3_tir::reduce::apply_args(tir, args)?;
147
148 Ok(tir)
149}
150
151pub async fn resolve_tx<C, S>(
152 tx: AnyTir,
153 args: &ArgMap,
154 compiler: &mut C,
155 utxos: &S,
156 max_optimize_rounds: usize,
157) -> Result<CompiledTx, Error>
158where
159 C: Compiler<Expression = tir::Expression, CompilerOp = tir::CompilerOp>,
160 S: UtxoStore,
161{
162 let tx = safe_apply_args(tx, args)?;
163
164 let max_optimize_rounds = max_optimize_rounds.max(3);
165
166 let mut last_eval = None;
167 let mut rounds = 0;
168
169 while let Some(better) = eval_pass(&tx, compiler, utxos, last_eval.as_ref()).await? {
170 last_eval = Some(better);
171
172 if rounds > max_optimize_rounds {
173 break;
174 }
175
176 rounds += 1;
177 }
178
179 Ok(last_eval.unwrap())
180}