Skip to main content

ark_fees/
lib.rs

1//! Fee estimation library using CEL (Common Expression Language) for calculating Arkade transaction
2//! fees.
3//!
4//! This library provides an `Estimator` that evaluates CEL expressions to calculate fees
5//! based on input and output characteristics.
6
7use cel::Context;
8use cel::Program;
9use std::time::SystemTime;
10use std::time::UNIX_EPOCH;
11
/// A fee in satoshis, stored as a floating point value so that fractional results produced by
/// CEL programs are preserved until the final conversion.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct FeeAmount(pub f64);

impl FeeAmount {
    /// Converts the fee to whole satoshis, rounding up.
    ///
    /// Negative amounts and NaN are treated as zero; values too large for `u64` saturate to
    /// `u64::MAX` (Rust's float-to-int `as` cast saturates).
    pub fn to_satoshis(&self) -> u64 {
        let FeeAmount(raw) = *self;
        // `f64::max` returns the non-NaN operand, so NaN also collapses to 0.0 here.
        let non_negative = f64::max(raw, 0.0);
        non_negative.ceil() as u64
    }
}

impl std::ops::Add for FeeAmount {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        let FeeAmount(lhs) = self;
        let FeeAmount(rhs) = other;
        Self(lhs + rhs)
    }
}

impl std::ops::AddAssign for FeeAmount {
    fn add_assign(&mut self, other: Self) {
        *self = *self + other;
    }
}

impl From<f64> for FeeAmount {
    fn from(value: f64) -> Self {
        Self(value)
    }
}
42
/// Kind of offchain input being priced; CEL programs see it as the `inputType` string.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum VtxoType {
    #[default]
    Vtxo,
    Recoverable,
    Note,
}

impl VtxoType {
    /// Name of this input type as exposed to CEL expressions.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Vtxo => "vtxo",
            Self::Recoverable => "recoverable",
            Self::Note => "note",
        }
    }
}

impl std::str::FromStr for VtxoType {
    type Err = String;

    /// Parses the CEL name of an input type (the inverse of [`VtxoType::as_str`]).
    ///
    /// Matching is case-sensitive; any other string yields an `unknown vtxo type` error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        [VtxoType::Vtxo, VtxoType::Recoverable, VtxoType::Note]
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
            .ok_or_else(|| format!("unknown vtxo type: {}", s))
    }
}
75
76/// An offchain input (VTXO) for fee calculation.
77#[derive(Debug, Clone, Default)]
78pub struct OffchainInput {
79    /// Amount in satoshis.
80    pub amount: u64,
81    /// Expiry time as Unix timestamp in seconds (optional).
82    pub expiry: Option<i64>,
83    /// Birth time as Unix timestamp in seconds (optional).
84    pub birth: Option<i64>,
85    /// Type of the input.
86    pub input_type: VtxoType,
87    /// Weighted liquidity lockup ratio.
88    pub weight: f64,
89}
90
/// A boarding (onchain) input whose fee is being estimated.
#[derive(Clone, Debug, Default)]
pub struct OnchainInput {
    /// Value of the boarding input, in satoshis.
    pub amount: u64,
}
97
/// A transaction output whose fee is being estimated, used for both offchain (VTXO) and
/// onchain (collaborative exit) outputs.
#[derive(Clone, Debug, Default)]
pub struct Output {
    /// Value of the output, in satoshis.
    pub amount: u64,
    /// The output's pkscript, hex encoded.
    pub script: String,
}
106
/// Source text of the CEL programs an `Estimator` is built from.
///
/// Every program is optional: leaving a field as an empty string means the corresponding fee
/// always evaluates to zero.
#[derive(Clone, Debug, Default)]
pub struct Config {
    /// Program evaluated for each offchain input (VTXO).
    pub intent_offchain_input_program: String,
    /// Program evaluated for each onchain (boarding) input.
    pub intent_onchain_input_program: String,
    /// Program evaluated for each offchain output.
    pub intent_offchain_output_program: String,
    /// Program evaluated for each onchain output.
    pub intent_onchain_output_program: String,
}
119
120/// A compiled CEL program that can be evaluated.
121struct CompiledProgram {
122    program: Program,
123    #[allow(dead_code)]
124    source: String,
125}
126
127impl std::fmt::Debug for CompiledProgram {
128    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
129        f.debug_struct("CompiledProgram")
130            .field("source", &self.source)
131            .finish_non_exhaustive()
132    }
133}
134
135/// Fee estimator using CEL expressions.
136#[derive(Debug)]
137pub struct Estimator {
138    intent_offchain_input: Option<CompiledProgram>,
139    intent_onchain_input: Option<CompiledProgram>,
140    intent_offchain_output: Option<CompiledProgram>,
141    intent_onchain_output: Option<CompiledProgram>,
142}
143
/// Failures that can occur while building an estimator or evaluating fees.
#[derive(Debug)]
pub enum Error {
    /// A program failed to compile, or failed its validation dry run.
    Compile(String),
    /// A program failed while being evaluated against real inputs.
    Eval(String),
    /// A program produced a value of an unsupported type.
    ReturnType(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Error::Compile(detail) => ("compile error: ", detail),
            Error::Eval(detail) => ("eval error: ", detail),
            // Return-type messages are already self-describing, so no prefix is added.
            Error::ReturnType(detail) => ("", detail),
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for Error {}
166
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Registered as the `now()` function in every CEL context. Sub-second precision is dropped,
/// and a system clock set before the epoch yields `0.0`.
fn now() -> f64 {
    let whole_seconds = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    whole_seconds as f64
}
174
/// Selects which set of CEL variables a program is validated against.
#[derive(Clone, Copy, Debug)]
enum ProgramType {
    /// Offchain (VTXO) input programs.
    OffchainInput,
    /// Onchain (boarding) input programs.
    OnchainInput,
    /// Output programs; offchain and onchain outputs expose the same variables.
    Output,
}
182
183/// Compiles a CEL program and validates it returns a numeric type.
184fn compile_program(source: &str, program_type: ProgramType) -> Result<CompiledProgram, Error> {
185    let program = Program::compile(source).map_err(|e| Error::Compile(format!("{}", e)))?;
186
187    // Validate by doing a dry run with dummy values
188    let context = create_validation_context(program_type);
189    let result = program
190        .execute(&context)
191        .map_err(|e| Error::Compile(format!("{}", e)))?;
192
193    // Verify the return type is double (float)
194    // We strictly require Float to match cel-go's behavior
195    match result {
196        cel::Value::Float(_) => {}
197        cel::Value::Int(_)
198        | cel::Value::UInt(_)
199        | cel::Value::String(_)
200        | cel::Value::Bytes(_)
201        | cel::Value::Bool(_)
202        | cel::Value::List(_)
203        | cel::Value::Map(_)
204        | cel::Value::Null
205        | cel::Value::Duration(_)
206        | cel::Value::Timestamp(_)
207        | cel::Value::Function(_, _)
208        | cel::Value::Opaque(_) => {
209            return Err(Error::ReturnType(format!(
210                "expected return type double, got {:?}",
211                result
212            )));
213        }
214    }
215
216    Ok(CompiledProgram {
217        program,
218        source: source.to_string(),
219    })
220}
221
222/// Creates a context with dummy values for program validation.
223fn create_validation_context(program_type: ProgramType) -> Context<'static> {
224    let mut context = Context::default();
225
226    match program_type {
227        ProgramType::OffchainInput => {
228            let _ = context.add_variable("amount", 0.0_f64);
229            let _ = context.add_variable("inputType", "vtxo");
230            let _ = context.add_variable("weight", 0.0_f64);
231            let _ = context.add_variable("expiry", 0.0_f64);
232            let _ = context.add_variable("birth", 0.0_f64);
233        }
234        ProgramType::OnchainInput => {
235            let _ = context.add_variable("amount", 0.0_f64);
236        }
237        ProgramType::Output => {
238            let _ = context.add_variable("amount", 0.0_f64);
239            let _ = context.add_variable("script", "");
240        }
241    }
242
243    context.add_function("now", now);
244    context
245}
246
247/// Creates a CEL context for offchain input evaluation.
248fn create_offchain_input_context(input: &OffchainInput) -> Context<'static> {
249    let mut context = Context::default();
250    let _ = context.add_variable("amount", input.amount as f64);
251    let _ = context.add_variable("inputType", input.input_type.as_str());
252    let _ = context.add_variable("weight", input.weight);
253
254    // Always add expiry and birth to match validation context.
255    // Default to 0.0 (Unix epoch) when not provided.
256    let _ = context.add_variable("expiry", input.expiry.unwrap_or(0) as f64);
257    let _ = context.add_variable("birth", input.birth.unwrap_or(0) as f64);
258
259    context.add_function("now", now);
260    context
261}
262
263/// Creates a CEL context for onchain input evaluation.
264fn create_onchain_input_context(input: &OnchainInput) -> Context<'static> {
265    let mut context = Context::default();
266    let _ = context.add_variable("amount", input.amount as f64);
267    context.add_function("now", now);
268    context
269}
270
271/// Creates a CEL context for output evaluation.
272fn create_output_context(output: &Output) -> Context<'static> {
273    let mut context = Context::default();
274    let _ = context.add_variable("amount", output.amount as f64);
275    let _ = context.add_variable("script", output.script.clone());
276    context.add_function("now", now);
277    context
278}
279
280impl Estimator {
281    /// Creates a new fee estimator from the given configuration.
282    ///
283    /// Programs are optional; if empty, the corresponding fee evaluation returns 0.
284    pub fn new(config: Config) -> Result<Self, Error> {
285        let intent_offchain_input = if !config.intent_offchain_input_program.is_empty() {
286            Some(compile_program(
287                &config.intent_offchain_input_program,
288                ProgramType::OffchainInput,
289            )?)
290        } else {
291            None
292        };
293
294        let intent_onchain_input = if !config.intent_onchain_input_program.is_empty() {
295            Some(compile_program(
296                &config.intent_onchain_input_program,
297                ProgramType::OnchainInput,
298            )?)
299        } else {
300            None
301        };
302
303        let intent_offchain_output = if !config.intent_offchain_output_program.is_empty() {
304            Some(compile_program(
305                &config.intent_offchain_output_program,
306                ProgramType::Output,
307            )?)
308        } else {
309            None
310        };
311
312        let intent_onchain_output = if !config.intent_onchain_output_program.is_empty() {
313            Some(compile_program(
314                &config.intent_onchain_output_program,
315                ProgramType::Output,
316            )?)
317        } else {
318            None
319        };
320
321        Ok(Estimator {
322            intent_offchain_input,
323            intent_onchain_input,
324            intent_offchain_output,
325            intent_onchain_output,
326        })
327    }
328
329    /// Evaluates the fee for a given offchain input (VTXO).
330    pub fn eval_offchain_input(&self, input: OffchainInput) -> Result<FeeAmount, Error> {
331        match &self.intent_offchain_input {
332            Some(compiled) => {
333                let context = create_offchain_input_context(&input);
334                let result = compiled
335                    .program
336                    .execute(&context)
337                    .map_err(|e| Error::Eval(format!("{}", e)))?;
338
339                match result {
340                    cel::Value::Float(f) => Ok(FeeAmount(f)),
341                    cel::Value::Int(i) => Ok(FeeAmount(i as f64)),
342                    cel::Value::UInt(u) => Ok(FeeAmount(u as f64)),
343                    cel::Value::String(_)
344                    | cel::Value::Bytes(_)
345                    | cel::Value::Bool(_)
346                    | cel::Value::List(_)
347                    | cel::Value::Map(_)
348                    | cel::Value::Null
349                    | cel::Value::Duration(_)
350                    | cel::Value::Timestamp(_)
351                    | cel::Value::Function(_, _)
352                    | cel::Value::Opaque(_) => Err(Error::ReturnType(format!(
353                        "expected return type double, got {:?}",
354                        result
355                    ))),
356                }
357            }
358            None => Ok(FeeAmount(0.0)),
359        }
360    }
361
362    /// Evaluates the fee for a given onchain input (boarding).
363    pub fn eval_onchain_input(&self, input: OnchainInput) -> Result<FeeAmount, Error> {
364        match &self.intent_onchain_input {
365            Some(compiled) => {
366                let context = create_onchain_input_context(&input);
367                let result = compiled
368                    .program
369                    .execute(&context)
370                    .map_err(|e| Error::Eval(format!("{}", e)))?;
371
372                match result {
373                    cel::Value::Float(f) => Ok(FeeAmount(f)),
374                    cel::Value::Int(i) => Ok(FeeAmount(i as f64)),
375                    cel::Value::UInt(u) => Ok(FeeAmount(u as f64)),
376                    cel::Value::String(_)
377                    | cel::Value::Bytes(_)
378                    | cel::Value::Bool(_)
379                    | cel::Value::List(_)
380                    | cel::Value::Map(_)
381                    | cel::Value::Null
382                    | cel::Value::Duration(_)
383                    | cel::Value::Timestamp(_)
384                    | cel::Value::Function(_, _)
385                    | cel::Value::Opaque(_) => Err(Error::ReturnType(format!(
386                        "expected return type double, got {:?}",
387                        result
388                    ))),
389                }
390            }
391            None => Ok(FeeAmount(0.0)),
392        }
393    }
394
395    /// Evaluates the fee for a given offchain output (VTXO).
396    pub fn eval_offchain_output(&self, output: Output) -> Result<FeeAmount, Error> {
397        match &self.intent_offchain_output {
398            Some(compiled) => {
399                let context = create_output_context(&output);
400                let result = compiled
401                    .program
402                    .execute(&context)
403                    .map_err(|e| Error::Eval(format!("{}", e)))?;
404
405                match result {
406                    cel::Value::Float(f) => Ok(FeeAmount(f)),
407                    cel::Value::Int(i) => Ok(FeeAmount(i as f64)),
408                    cel::Value::UInt(u) => Ok(FeeAmount(u as f64)),
409                    cel::Value::String(_)
410                    | cel::Value::Bytes(_)
411                    | cel::Value::Bool(_)
412                    | cel::Value::List(_)
413                    | cel::Value::Map(_)
414                    | cel::Value::Null
415                    | cel::Value::Duration(_)
416                    | cel::Value::Timestamp(_)
417                    | cel::Value::Function(_, _)
418                    | cel::Value::Opaque(_) => Err(Error::ReturnType(format!(
419                        "expected return type double, got {:?}",
420                        result
421                    ))),
422                }
423            }
424            None => Ok(FeeAmount(0.0)),
425        }
426    }
427
428    /// Evaluates the fee for a given onchain output (collaborative exit).
429    pub fn eval_onchain_output(&self, output: Output) -> Result<FeeAmount, Error> {
430        match &self.intent_onchain_output {
431            Some(compiled) => {
432                let context = create_output_context(&output);
433                let result = compiled
434                    .program
435                    .execute(&context)
436                    .map_err(|e| Error::Eval(format!("{}", e)))?;
437
438                match result {
439                    cel::Value::Float(f) => Ok(FeeAmount(f)),
440                    cel::Value::Int(i) => Ok(FeeAmount(i as f64)),
441                    cel::Value::UInt(u) => Ok(FeeAmount(u as f64)),
442                    cel::Value::String(_)
443                    | cel::Value::Bytes(_)
444                    | cel::Value::Bool(_)
445                    | cel::Value::List(_)
446                    | cel::Value::Map(_)
447                    | cel::Value::Null
448                    | cel::Value::Duration(_)
449                    | cel::Value::Timestamp(_)
450                    | cel::Value::Function(_, _)
451                    | cel::Value::Opaque(_) => Err(Error::ReturnType(format!(
452                        "expected return type double, got {:?}",
453                        result
454                    ))),
455                }
456            }
457            None => Ok(FeeAmount(0.0)),
458        }
459    }
460
461    /// Evaluates the total fee for a given set of inputs and outputs.
462    pub fn eval(
463        &self,
464        offchain_inputs: &[OffchainInput],
465        onchain_inputs: &[OnchainInput],
466        offchain_outputs: &[Output],
467        onchain_outputs: &[Output],
468    ) -> Result<FeeAmount, Error> {
469        let mut fee = FeeAmount(0.0);
470
471        for input in offchain_inputs {
472            fee += self.eval_offchain_input(input.clone())?;
473        }
474
475        for input in onchain_inputs {
476            fee += self.eval_onchain_input(input.clone())?;
477        }
478
479        for output in offchain_outputs {
480            fee += self.eval_offchain_output(output.clone())?;
481        }
482
483        for output in onchain_outputs {
484            fee += self.eval_onchain_output(output.clone())?;
485        }
486
487        Ok(fee)
488    }
489}