Skip to main content

async_hwi/jade/
mod.rs

1pub mod api;
2pub mod pinserver;
3
4use std::{
5    collections::BTreeMap,
6    fmt::Debug,
7    str::FromStr,
8    sync::Arc,
9    time::{SystemTime, UNIX_EPOCH},
10};
11
12use serde::{de::DeserializeOwned, Serialize};
13
14use bitcoin::{
15    bip32::{DerivationPath, Fingerprint, Xpub},
16    psbt::Psbt,
17    Network,
18};
19
20use serialport::{available_ports, SerialPort, SerialPortType};
21use tokio::{
22    io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
23    sync::Mutex,
24};
25use tokio_serial::SerialPortBuilderExt;
26
27pub use tokio_serial::SerialStream;
28
29use crate::{parse_version, utils};
30
31use super::{AddressScript, DeviceKind, Error as HWIError, HWI};
32use async_trait::async_trait;
33
/// Network identifier sent to the Jade when targeting Bitcoin mainnet.
pub const JADE_NETWORK_MAINNET: &str = "mainnet";
/// Network identifier sent to the Jade for every network other than mainnet.
pub const JADE_NETWORK_TESTNET: &str = "testnet";
36
37#[derive(Debug)]
38pub struct Jade<T> {
39    transport: T,
40    network: &'static str,
41    kind: DeviceKind,
42    descriptor_name: Option<String>,
43}
44
45impl<T: Transport + Sync + Send> Jade<T> {
46    pub fn new(transport: T) -> Self {
47        Self {
48            transport,
49            network: JADE_NETWORK_MAINNET,
50            kind: DeviceKind::Jade,
51            descriptor_name: None,
52        }
53    }
54
55    pub fn with_network(mut self, network: Network) -> Self {
56        if network == Network::Bitcoin {
57            self.network = JADE_NETWORK_MAINNET;
58        } else {
59            self.network = JADE_NETWORK_TESTNET;
60        }
61        self
62    }
63
64    pub fn with_wallet(mut self, descriptor_name: String) -> Self {
65        self.descriptor_name = Some(descriptor_name);
66        self
67    }
68
69    pub async fn ping(&self) -> Result<(), JadeError> {
70        let _res: u64 = self
71            .transport
72            .request("ping", Option::<api::EmptyRequest>::None)
73            .await?
74            .into_result()?;
75        Ok(())
76    }
77
78    pub async fn get_info(&self) -> Result<api::GetInfoResponse, HWIError> {
79        let info: api::GetInfoResponse = self
80            .transport
81            .request("get_version_info", Option::<api::EmptyRequest>::None)
82            .await?
83            .into_result()?;
84        Ok(info)
85    }
86
87    pub async fn get_registered_descriptors(
88        &self,
89    ) -> Result<BTreeMap<String, api::DescriptorInfoResponse>, HWIError> {
90        let descriptors: BTreeMap<String, api::DescriptorInfoResponse> = self
91            .transport
92            .request(
93                "get_registered_descriptors",
94                Option::<api::EmptyRequest>::None,
95            )
96            .await?
97            .into_result()?;
98        Ok(descriptors)
99    }
100
101    pub async fn get_registered_descriptor(
102        &self,
103        name: &str,
104    ) -> Result<api::GetRegisteredDescriptorResponse, HWIError> {
105        let registered: api::GetRegisteredDescriptorResponse = self
106            .transport
107            .request(
108                "get_registered_descriptor",
109                Some(api::GetRegisteredDescriptorParams {
110                    descriptor_name: name,
111                }),
112            )
113            .await?
114            .into_result()?;
115        Ok(registered)
116    }
117
118    pub async fn auth(&self) -> Result<(), JadeError> {
119        let res: api::AuthUserResponse = self
120            .transport
121            .request(
122                "auth_user",
123                Some(api::AuthUserParams {
124                    network: self.network,
125                    epoch: SystemTime::now()
126                        .duration_since(UNIX_EPOCH)
127                        .ok()
128                        .map(|t| t.as_secs())
129                        .unwrap_or(0),
130                }),
131            )
132            .await?
133            .into_result()?;
134
135        if let api::AuthUserResponse::PinServerRequired { http_request } = res {
136            let client = pinserver::PinServerClient::new();
137            let pin_params: api::PinParams = client.request(http_request.params).await?;
138            let handshake_completed: bool = self
139                .transport
140                .request("pin", Some(pin_params))
141                .await?
142                .into_result()?;
143            if !handshake_completed {
144                return Err(JadeError::HandShakeRefused);
145            }
146        }
147        Ok(())
148    }
149}
150
151#[async_trait]
152impl<T: Transport + Sync + Send> HWI for Jade<T> {
153    fn device_kind(&self) -> DeviceKind {
154        self.kind
155    }
156
157    async fn get_version(&self) -> Result<super::Version, HWIError> {
158        let info = self.get_info().await?;
159        parse_version(&info.jade_version)
160    }
161
162    async fn get_master_fingerprint(&self) -> Result<Fingerprint, HWIError> {
163        let xpub = self.get_extended_pubkey(&DerivationPath::master()).await?;
164        Ok(xpub.fingerprint())
165    }
166
167    async fn get_extended_pubkey(&self, path: &DerivationPath) -> Result<Xpub, HWIError> {
168        let s: String = self
169            .transport
170            .request(
171                "get_xpub",
172                Some(api::GetXpubParams {
173                    network: self.network,
174                    path: path.to_u32_vec(),
175                }),
176            )
177            .await?
178            .into_result()?;
179        let xpub = Xpub::from_str(&s).map_err(|e| HWIError::Device(e.to_string()))?;
180        Ok(xpub)
181    }
182
183    async fn display_address(&self, script: &AddressScript) -> Result<(), HWIError> {
184        match (self.descriptor_name.as_ref(), script) {
185            (Some(descriptor_name), AddressScript::Miniscript { index, change }) => {
186                let _address: String = self
187                    .transport
188                    .request(
189                        "get_receive_address",
190                        Some(api::DescriptorAddressParams {
191                            network: self.network,
192                            branch: u32::from(*change),
193                            pointer: *index,
194                            descriptor_name,
195                        }),
196                    )
197                    .await?
198                    .into_result()?;
199                Ok(())
200            }
201            _ => Err(HWIError::UnimplementedMethod),
202        }
203    }
204
205    async fn register_wallet(
206        &self,
207        name: &str,
208        policy: &str,
209    ) -> Result<Option<[u8; 32]>, HWIError> {
210        let (descriptor_template, keys) = utils::extract_keys_and_template::<String>(policy)?;
211        let registered: bool = self
212            .transport
213            .request(
214                "register_descriptor",
215                Some(api::RegisterDescriptorParams {
216                    network: self.network,
217                    descriptor_name: name,
218                    descriptor: descriptor_template,
219                    datavalues: keys
220                        .into_iter()
221                        .enumerate()
222                        .map(|(i, key)| (format!("@{}", i), key))
223                        .collect(),
224                }),
225            )
226            .await?
227            .into_result()?;
228        if !registered {
229            Err(HWIError::UserRefused)
230        } else {
231            Ok(None)
232        }
233    }
234
235    async fn is_wallet_registered(&self, name: &str, policy: &str) -> Result<bool, HWIError> {
236        let registered_descriptors = self.get_registered_descriptors().await?;
237        if !registered_descriptors.contains_key(name) {
238            return Ok(false);
239        }
240
241        let registered = self.get_registered_descriptor(name).await?;
242
243        let (descriptor_template, keys) = utils::extract_keys_and_template::<String>(policy)?;
244        let datavalues: BTreeMap<String, String> = keys
245            .into_iter()
246            .enumerate()
247            .map(|(i, key)| (format!("@{}", i), key))
248            .collect();
249
250        Ok(registered.descriptor_name == name
251            && registered.descriptor == descriptor_template
252            && registered.datavalues == datavalues)
253    }
254
255    async fn sign_tx(&self, psbt: &mut Psbt) -> Result<(), HWIError> {
256        let first: api::Response<serde_bytes::ByteBuf> = self
257            .transport
258            .request(
259                "sign_psbt",
260                Some(api::SignPsbtParams {
261                    network: self.network,
262                    psbt: Psbt::serialize(psbt),
263                }),
264            )
265            .await?;
266
267        if let Some(e) = first.error {
268            return Err(JadeError::Rpc(e).into());
269        }
270
271        let mut psbt_bytes = first
272            .result
273            .ok_or(JadeError::Transport(TransportError::NoErrorOrResult))?;
274
275        if let (Some(mut seqlen), Some(mut seqnum)) = (first.seqlen, first.seqnum) {
276            if seqlen > 1 {
277                while seqnum < seqlen {
278                    let mut res: api::Response<serde_bytes::ByteBuf> = self
279                        .transport
280                        .request(
281                            "get_extended_data",
282                            Some(api::GetExtendedDataParams {
283                                origid: &first.id,
284                                orig: "sign_psbt",
285                                seqnum: seqnum + 1,
286                                seqlen,
287                            }),
288                        )
289                        .await?;
290
291                    if let Some(e) = res.error {
292                        return Err(JadeError::Rpc(e).into());
293                    }
294
295                    if let Some(bytes) = res.result.as_mut() {
296                        psbt_bytes.append(bytes);
297                    } else {
298                        return Err(JadeError::Transport(TransportError::NoErrorOrResult).into());
299                    }
300
301                    if let (Some(len), Some(num)) = (res.seqlen, res.seqnum) {
302                        seqlen = len;
303                        seqnum = num;
304                    } else {
305                        return Err(JadeError::Transport(TransportError::NoErrorOrResult).into());
306                    }
307                }
308            }
309        }
310
311        let signed_psbt =
312            Psbt::deserialize(&psbt_bytes).map_err(|e| HWIError::Device(e.to_string()))?;
313        utils::merge_signatures(psbt, &signed_psbt);
314
315        Ok(())
316    }
317}
318
319impl<T: 'static + Transport + Sync + Send> From<Jade<T>> for Box<dyn HWI + Send> {
320    fn from(s: Jade<T>) -> Box<dyn HWI + Send> {
321        Box::new(s)
322    }
323}
324
325async fn exchange<S, D>(
326    transport: &mut SerialStream,
327    method: &str,
328    params: Option<S>,
329) -> Result<api::Response<D>, JadeError>
330where
331    S: Serialize + Unpin,
332    D: DeserializeOwned + Unpin,
333{
334    let (reader, mut writer) = tokio::io::split(transport);
335
336    let id = std::process::id();
337    let req = serde_cbor::to_vec(&api::Request {
338        id: &id.to_string(),
339        method,
340        params,
341    })
342    .map_err(TransportError::from)?;
343
344    writer.write_all(&req).await.map_err(TransportError::from)?;
345
346    let response = read_stream(reader).await?;
347
348    if response.id != id.to_string() {
349        return Err(TransportError::NonceMismatch.into());
350    }
351
352    Ok(response)
353}
354
355async fn read_stream<D: DeserializeOwned, S: AsyncRead + Unpin>(
356    mut stream: S,
357) -> Result<api::Response<D>, TransportError> {
358    let mut buf = Vec::<u8>::new();
359    let mut chunk = [0; 1024];
360    let n = stream.read(&mut chunk).await?;
361    buf.extend_from_slice(&chunk[..n]);
362    if let Ok(response) = serde_cbor::from_slice(&buf) {
363        return Ok(response);
364    }
365    loop {
366        tokio::select! {
367            res = stream.read(&mut chunk) => {
368                let n = res?;
369                if n == 0 {
370                    break;
371                }
372                buf.extend_from_slice(&chunk[..n]);
373                if let Ok(response) = serde_cbor::from_slice(&buf) {
374                    return Ok(response);
375                }
376            }
377            _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
378                break;
379            }
380        }
381    }
382    match serde_cbor::from_slice(&buf) {
383        Ok(response) => Ok(response),
384        Err(_) => Err(TransportError::NoErrorOrResult),
385    }
386}
387
388#[async_trait]
389pub trait Transport: Debug {
390    async fn request<S: Serialize + Send + Unpin, D: DeserializeOwned + Unpin + Send>(
391        &self,
392        method: &str,
393        params: Option<S>,
394    ) -> Result<api::Response<D>, JadeError>;
395}
396
397impl Jade<SerialTransport> {
398    pub async fn enumerate() -> Result<Vec<Self>, JadeError> {
399        let mut res = Vec::new();
400        for port_name in SerialTransport::enumerate_potential_ports()? {
401            let jade = Jade::<SerialTransport>::new(SerialTransport::new(port_name)?);
402            jade.ping().await?;
403            res.push(jade);
404        }
405        Ok(res)
406    }
407}
408
409#[derive(Debug)]
410pub struct SerialTransport {
411    pub stream: Arc<Mutex<SerialStream>>,
412}
413
/// USB `(vendor id, product id)` pairs treated as potential Jade devices by
/// `SerialTransport::enumerate_potential_ports`.
///
/// NOTE(review): several entries look like generic USB-serial bridge or
/// Espressif ids rather than Jade-specific ones, so other devices may match.
pub const JADE_DEVICE_IDS: [(u16, u16); 6] = [
    (0x10c4, 0xea60),
    (0x1a86, 0x55d4),
    (0x0403, 0x6001),
    (0x1a86, 0x7523),
    (0x303a, 0x4001),
    (0x303a, 0x1001),
];
422
423impl SerialTransport {
424    pub fn new(port_name: String) -> Result<Self, TransportError> {
425        let mut transport = tokio_serial::new(port_name, DEFAULT_JADE_BAUD_RATE)
426            .open_native_async()
427            .map_err(TransportError::from)?;
428        // Ensure RTS and DTR are not set (as this can cause the hw to reboot)
429        // according to https://github.com/Blockstream/Jade/blob/master/jadepy/jade_serial.py#L56
430        transport
431            .write_request_to_send(false)
432            .map_err(TransportError::from)?;
433        transport
434            .write_data_terminal_ready(false)
435            .map_err(TransportError::from)?;
436        Ok(Self {
437            stream: Arc::new(Mutex::new(transport)),
438        })
439    }
440    pub fn enumerate_potential_ports() -> Result<Vec<String>, JadeError> {
441        match available_ports() {
442            Ok(ports) => Ok(ports
443                .into_iter()
444                .filter_map(|p| match p.port_type {
445                    SerialPortType::UsbPort(info) => {
446                        if JADE_DEVICE_IDS.contains(&(info.vid, info.pid)) {
447                            Some(p.port_name)
448                        } else {
449                            None
450                        }
451                    }
452                    _ => None,
453                })
454                .collect()),
455            Err(e) => Err(JadeError::Transport(e.into())),
456        }
457    }
458}
459
/// Baud rate used when opening a Jade serial port.
const DEFAULT_JADE_BAUD_RATE: u32 = 115_200;
461
462#[async_trait]
463impl Transport for SerialTransport {
464    async fn request<S: Serialize + Send + Unpin, D: DeserializeOwned + Unpin + Send>(
465        &self,
466        method: &str,
467        params: Option<S>,
468    ) -> Result<api::Response<D>, JadeError> {
469        let mut stream = self.stream.lock().await;
470        exchange(&mut stream, method, params).await
471    }
472}
473
474#[derive(Debug)]
475pub enum TransportError {
476    Serialize(serde_cbor::Error),
477    NoErrorOrResult,
478    NonceMismatch,
479    Io(std::io::Error),
480    Serial(serialport::Error),
481}
482
483impl From<serde_cbor::Error> for TransportError {
484    fn from(e: serde_cbor::Error) -> Self {
485        Self::Serialize(e)
486    }
487}
488
489impl From<std::io::Error> for TransportError {
490    fn from(e: std::io::Error) -> Self {
491        Self::Io(e)
492    }
493}
494
495impl From<serialport::Error> for TransportError {
496    fn from(e: serialport::Error) -> Self {
497        Self::Serial(e)
498    }
499}
500
501impl std::fmt::Display for TransportError {
502    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
503        match self {
504            Self::Serialize(e) => write!(f, "{}", e),
505            Self::NoErrorOrResult => write!(f, "No Error or Result"),
506            Self::NonceMismatch => write!(f, "Nonce mismatched"),
507            Self::Io(e) => write!(f, "{}", e),
508            Self::Serial(e) => write!(f, "{}", e),
509        }
510    }
511}
512
513#[derive(Debug)]
514pub enum JadeError {
515    Transport(TransportError),
516    Rpc(api::Error),
517    PinServer(pinserver::Error),
518    HandShakeRefused,
519}
520
521impl From<TransportError> for JadeError {
522    fn from(e: TransportError) -> Self {
523        Self::Transport(e)
524    }
525}
526
527impl From<pinserver::Error> for JadeError {
528    fn from(e: pinserver::Error) -> Self {
529        Self::PinServer(e)
530    }
531}
532
533impl std::fmt::Display for JadeError {
534    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
535        match self {
536            Self::Transport(e) => write!(f, "{}", e),
537            Self::Rpc(e) => write!(f, "{:?}", e),
538            Self::PinServer(e) => write!(f, "{:?}", e),
539            Self::HandShakeRefused => write!(f, "Handshake with pinserver refused"),
540        }
541    }
542}
543
544impl From<JadeError> for HWIError {
545    fn from(e: JadeError) -> HWIError {
546        match e {
547            JadeError::Transport(e) => HWIError::Device(e.to_string()),
548            JadeError::Rpc(e) => {
549                if e.code == api::ErrorCode::UserCancelled as i32 {
550                    HWIError::UserRefused
551                } else if e.code == api::ErrorCode::NetworkMismatch as i32 {
552                    HWIError::NetworkMismatch
553                } else {
554                    HWIError::Device(format!("{:?}", e))
555                }
556            }
557            JadeError::PinServer(e) => HWIError::Device(format!("{:?}", e)),
558            JadeError::HandShakeRefused => {
559                HWIError::Device("Handshake with pinserver refused".to_string())
560            }
561        }
562    }
563}