1use std::time::Duration;
4
5use encdec::{EncDec, Encode};
6use tracing::{debug, error};
7
8use ledger_proto::{
9 apdus::{AppInfoReq, AppInfoResp, DeviceInfoReq, DeviceInfoResp},
10 ApduError, ApduReq, StatusCode,
11};
12
13use crate::{
14 info::{AppInfo, DeviceInfo},
15 Error, Exchange,
16};
17
18const APDU_BUFF_LEN: usize = 256;
19
20#[cfg_attr(not(feature = "unstable_async_trait"), async_trait::async_trait)]
22pub trait Device {
23 async fn request<'a, 'b, RESP: EncDec<'b, ApduError>>(
25 &mut self,
26 request: impl ApduReq<'a> + Send,
27 buff: &'b mut [u8],
28 timeout: Duration,
29 ) -> Result<RESP, Error>;
30
31 async fn app_info(&mut self, timeout: Duration) -> Result<AppInfo, Error> {
33 let mut buff = [0u8; APDU_BUFF_LEN];
34
35 let r = self
36 .request::<AppInfoResp>(AppInfoReq {}, &mut buff[..], timeout)
37 .await?;
38
39 Ok(AppInfo {
40 name: r.name.to_string(),
41 version: r.version.to_string(),
42 flags: r.flags,
43 })
44 }
45
46 async fn device_info(&mut self, timeout: Duration) -> Result<DeviceInfo, Error> {
48 let mut buff = [0u8; APDU_BUFF_LEN];
49
50 let r = self
51 .request::<DeviceInfoResp>(DeviceInfoReq {}, &mut buff[..], timeout)
52 .await?;
53
54 Ok(DeviceInfo {
55 target_id: r.target_id,
56 se_version: r.se_version.to_string(),
57 mcu_version: r.mcu_version.to_string(),
58 flags: r.flags.to_vec(),
59 })
60 }
61}
62
63#[cfg_attr(not(feature = "unstable_async_trait"), async_trait::async_trait)]
65impl<T: Exchange + Send> Device for T {
66 async fn request<'a, 'b, RESP: EncDec<'b, ApduError>>(
68 &mut self,
69 req: impl ApduReq<'a> + Send,
70 buff: &'b mut [u8],
71 timeout: Duration,
72 ) -> Result<RESP, Error> {
73 debug!("TX: {req:?}");
74
75 let n = encode_request(req, buff)?;
77
78 let resp_bytes = self.exchange(&buff[..n], timeout).await?;
80
81 let n = resp_bytes.len();
84 if n > buff.len() {
85 error!(
86 "Response length exceeds buffer length ({} > {})",
87 n,
88 buff.len()
89 );
90 return Err(ApduError::InvalidLength.into());
91 }
92 buff[..n].copy_from_slice(&resp_bytes[..]);
93
94 if n == 2 {
96 let v = u16::from_be_bytes([resp_bytes[0], resp_bytes[1]]);
98 match StatusCode::try_from(v) {
99 Ok(c) => return Err(Error::Status(c)),
100 Err(_) => return Err(Error::UnknownStatus(resp_bytes[0], resp_bytes[1])),
101 }
102 }
103
104 let (resp, _) = RESP::decode(&buff[..n - 2])?;
106
107 debug!("RX: {resp:?}");
108
109 Ok(resp)
111 }
112}
113
114fn encode_request<'a, REQ: ApduReq<'a>>(req: REQ, buff: &mut [u8]) -> Result<usize, Error> {
116 let mut index = 0;
117
118 let data_len = req.encode_len()?;
119
120 if buff.len() < 5 + data_len {
122 return Err(ApduError::InvalidLength.into());
123 }
124
125 let h = req.header();
129 index += h.encode(&mut buff[index..])?;
130
131 if data_len > u8::MAX as usize {
133 return Err(ApduError::InvalidLength.into());
134 }
135 buff[index] = data_len as u8;
136 index += 1;
137
138 index += req.encode(&mut buff[index..])?;
140
141 Ok(index)
142}
143
#[cfg(test)]
mod tests {
    use ledger_proto::{apdus::AppInfoReq, ApduStatic};

    use super::encode_request;

    /// An empty-payload request encodes as CLA, INS, two zero header bytes and a
    /// zero length byte.
    #[test]
    fn test_encode_requests() {
        let mut buff = [0u8; 256];

        let req = AppInfoReq {};
        let n = encode_request(req, &mut buff).unwrap();
        assert_eq!(n, 5);
        assert_eq!(
            &buff[..n],
            &[AppInfoReq::CLA, AppInfoReq::INS, 0x00, 0x00, 0x00]
        );
    }

    /// A buffer too small for the five framing bytes is rejected with an error.
    #[test]
    fn test_encode_request_buffer_too_small() {
        let mut buff = [0u8; 4];

        assert!(encode_request(AppInfoReq {}, &mut buff).is_err());
    }
}
162}