emulator_connect/servers/
hd.rs

1// Copyright Judica, Inc 2021
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4//  License, v. 2.0. If a copy of the MPL was not distributed with this
5//  file, You can obtain one at https://mozilla.org/MPL/2.0/.
6
7//! definitions for oracle servers
8use super::*;
9use bitcoin::util::sighash::Prevouts;
10use bitcoin::util::taproot::TapLeafHash;
11use bitcoin::util::taproot::TapSighashHash;
12use bitcoin::SchnorrSig;
13use bitcoin::Script;
14use bitcoin::TxOut;
15use bitcoin::XOnlyPublicKey;
16
17/// hierarchical deterministic oracle emulator
18#[derive(Clone)]
19pub struct HDOracleEmulator {
20    root: ExtendedPrivKey,
21    debug: bool,
22}
23
24impl HDOracleEmulator {
25    /// create a new HDOracleEmulator
26    ///
27    /// if debug is set, runs in a "single threaded" mode where we can observe errors on connections rather than ignoring them.
28    pub fn new(root: ExtendedPrivKey, debug: bool) -> Self {
29        HDOracleEmulator { root, debug }
30    }
31    /// binds a HDOracleEmulator to a socket interface and runs the server
32    ///
33    /// This will only return when debug = false if The TcpListener fails.
34    /// When debug = true, then we join each connection one at a time and return
35    /// any errors.
36    pub async fn bind<A: ToSocketAddrs>(self, a: A) -> std::io::Result<()> {
37        let listener = TcpListener::bind(a).await?;
38        loop {
39            let (mut socket, _) = listener.accept().await?;
40            {
41                let this = self.clone();
42                let j: tokio::task::JoinHandle<Result<(), std::io::Error>> =
43                    tokio::spawn(async move {
44                        loop {
45                            socket.readable().await?;
46                            this.handle(&mut socket).await?;
47                        }
48                    });
49                if self.debug {
50                    tokio::join!(j).0??;
51                }
52            }
53        }
54    }
55    /// helper to get an EPK for the oracle.
56    fn derive(&self, h: Sha256, secp: &Secp256k1<All>) -> Result<ExtendedPrivKey, Error> {
57        let c = hash_to_child_vec(h);
58        self.root.derive_priv(secp, &c)
59    }
60
61    /// Signs a PSBT with the correct derived key.
62    ///
63    /// Always signs for spending index 0.
64    ///
65    /// May fail to sign if the PSBT is not properly formatted
66    fn sign(
67        &self,
68        mut b: PartiallySignedTransaction,
69        secp: &Secp256k1<All>,
70    ) -> Result<PartiallySignedTransaction, std::io::Error> {
71        let tx = b.clone().extract_tx();
72        let h = tx.get_ctv_hash(0);
73        let utxos: Vec<TxOut> = b
74            .inputs
75            .iter()
76            .map(|o| o.witness_utxo.clone())
77            .collect::<Option<Vec<TxOut>>>()
78            .ok_or_else(|| input_err("Could not find one of the UTXOs to be signed over"))?;
79        let key = self
80            .derive(h, secp)
81            .map_err(|_| input_err("Could Not Derive Key"))?;
82        let untweaked = key.to_keypair(secp);
83        let pk = XOnlyPublicKey::from_keypair(&untweaked);
84        let mut sighash = bitcoin::util::sighash::SighashCache::new(&tx);
85        let input_zero = &mut b.inputs[0];
86        use bitcoin::schnorr::TapTweak;
87        let tweaked = untweaked
88            .tap_tweak(secp, input_zero.tap_merkle_root)
89            .into_inner();
90        let tweaked_pk = tweaked.public_key();
91        let hash_ty = bitcoin::util::sighash::SchnorrSighashType::All;
92        let prevouts = &Prevouts::All(&utxos);
93        let mut get_sig = |path, kp| {
94            let annex = None;
95            let sighash: TapSighashHash = sighash
96                .taproot_signature_hash(0, prevouts, annex, path, hash_ty)
97                .expect("Signature hash cannot fail...");
98            let msg = bitcoin::secp256k1::Message::from_slice(&sighash[..])
99                .expect("Size must be correct.");
100            let sig = secp.sign_schnorr_no_aux_rand(&msg, kp);
101            SchnorrSig { sig, hash_ty }
102        };
103        if let Some(true) = input_zero.witness_utxo.as_ref().map(|v| {
104            v.script_pubkey
105                == Script::new_v1_p2tr_tweaked(
106                    XOnlyPublicKey::from(tweaked_pk).dangerous_assume_tweaked(),
107                )
108        }) {
109            let sig = get_sig(None, &tweaked);
110            input_zero.tap_key_sig = Some(sig);
111        }
112        for tlh in input_zero
113            .tap_scripts
114            .values()
115            .map(|(script, ver)| TapLeafHash::from_script(script, *ver))
116        {
117            let sig = get_sig(Some((tlh, 0xffffffff)), &untweaked);
118            input_zero.tap_script_sigs.insert((pk.0, tlh), sig);
119        }
120        Ok(b)
121    }
122
123    /// the main server business logic.
124    ///
125    /// - on receiving Request::SignPSBT, signs the PSBT.
126    async fn handle(&self, t: &mut TcpStream) -> Result<(), std::io::Error> {
127        let request = Self::requested(t).await?;
128        match request {
129            msgs::Request::SignPSBT(msgs::PSBT(unsigned)) => {
130                let psbt = SECP.with(|secp| self.sign(unsigned, secp))?;
131                Self::respond(t, &msgs::PSBT(psbt)).await
132            }
133        }
134    }
135
136    /// receive a request via the tcpstream.
137    /// wire format: length:u32 data:[u8;length]
138    ///
139    /// TODO: DoS Critical: limit the allowed max length we will attempt to derserialize
140    async fn requested(t: &mut TcpStream) -> Result<msgs::Request, std::io::Error> {
141        let l = t.read_u32().await? as usize;
142        let mut v = vec![0u8; l];
143        t.read_exact(&mut v[..]).await?;
144        Ok(serde_json::from_slice(&v[..])?)
145    }
146
147    /// respond via the tcpstream.
148    /// wire format: length:u32 data:[u8;length]
149    async fn respond<T: Serialize>(t: &mut TcpStream, r: &T) -> Result<(), std::io::Error> {
150        let v = serde_json::to_vec(r)?;
151        t.write_u32(v.len() as u32).await?;
152        t.write_all(&v[..]).await?;
153        t.flush().await
154    }
155}