1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 os::fd::{AsFd, AsRawFd},
18 path::{Path, PathBuf},
19 rc::Rc,
20 time::{Duration, Instant},
21};
22
23use thiserror::Error;
24use tpm2_protocol::{
25 basic::{TpmHandle, TpmUint32},
26 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
27 data::{
28 Tpm2bName, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
29 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
30 TpmtPublic, TpmuCapabilities,
31 },
32 frame::{
33 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
34 TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame, TpmGetCapabilityCommand,
35 TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
36 },
37 TpmWriter,
38};
39use tracing::trace;
40
/// Errors reported by the TPM device transport in this file.
#[derive(Debug, Error)]
pub enum TpmDeviceError {
    /// The shared device could not be mutably borrowed because it is already borrowed.
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// The capability data returned by the TPM did not contain the requested capability
    /// or property.
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
    /// The interruption callback of the device returned `true` while waiting for a response.
    #[error("operation interrupted by user")]
    Interrupted,
    /// The response size field was out of range, or more bytes arrived than it announced.
    #[error("invalid response")]
    InvalidResponse,

    /// An I/O error from the device file (also used for converted `nix` errors).
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),

    /// The command could not be marshalled into the command buffer.
    #[error("marshal: {0}")]
    Marshal(tpm2_protocol::TpmProtocolError),

    /// No device was supplied to [`with_device`].
    #[error("device not available")]
    NotAvailable,
    /// A generic failure; in this file it is returned when `MAX_HANDLES` does not fit
    /// into a `u32`.
    #[error("operation failed")]
    OperationFailed,
    /// The TPM reported no PCR banks.
    #[error("PCR banks not available")]
    PcrBanksNotAvailable,
    /// The PCR banks reported by the TPM do not all have the same selection size.
    #[error("PCR bank size mismatch")]
    PcrBankSizeMismatch,

    /// The response could not be interpreted as the response of the given command.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),

    /// No complete response arrived within the configured timeout.
    #[error("TPM command timed out")]
    Timeout,
    /// The TPM answered with a return code instead of a successful response.
    #[error("TPM return code: {0}")]
    TpmRc(TpmRc),

    /// The response bytes could not be unmarshalled.
    #[error("unmarshal: {0}")]
    Unmarshal(tpm2_protocol::TpmProtocolError),

    /// The device reported an error or hang-up condition, or a read returned zero bytes.
    #[error("unexpected EOF")]
    UnexpectedEof,
}
85
86impl From<TpmRc> for TpmDeviceError {
87 fn from(rc: TpmRc) -> Self {
88 Self::TpmRc(rc)
89 }
90}
91
92impl From<nix::Error> for TpmDeviceError {
93 fn from(err: nix::Error) -> Self {
94 Self::Io(std::io::Error::from_raw_os_error(err as i32))
95 }
96}
97
98pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
110where
111 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
112 E: From<TpmDeviceError>,
113{
114 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
115 let mut device_guard = device_rc
116 .try_borrow_mut()
117 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
118 f(&mut device_guard)
119}
120
/// Collects the settings used to open a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    /// Path of the device node to open.
    path: PathBuf,
    /// Maximum time to wait for a complete response to one command.
    timeout: Duration,
    /// Callback polled while waiting for a response; returning `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
}
127
128impl Default for TpmDeviceBuilder {
129 fn default() -> Self {
130 Self {
131 path: PathBuf::from("/dev/tpmrm0"),
132 timeout: Duration::from_secs(120),
133 interrupted: Box::new(|| false),
134 }
135 }
136}
137
138impl TpmDeviceBuilder {
139 #[must_use]
141 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
142 self.path = path.as_ref().to_path_buf();
143 self
144 }
145
146 #[must_use]
148 pub fn with_timeout(mut self, timeout: Duration) -> Self {
149 self.timeout = timeout;
150 self
151 }
152
153 #[must_use]
155 pub fn with_interrupted<F>(mut self, handler: F) -> Self
156 where
157 F: Fn() -> bool + 'static,
158 {
159 self.interrupted = Box::new(handler);
160 self
161 }
162
163 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
170 let file = OpenOptions::new()
171 .read(true)
172 .write(true)
173 .open(&self.path)
174 .map_err(TpmDeviceError::Io)?;
175
176 let fd = file.as_raw_fd();
177 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
178 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
179 oflags.insert(fcntl::OFlag::O_NONBLOCK);
180 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
181
182 Ok(TpmDevice {
183 file,
184 name_cache: HashMap::new(),
185 interrupted: self.interrupted,
186 timeout: self.timeout,
187 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
188 response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
189 })
190 }
191}
192
/// An open TPM device together with its per-device state.
pub struct TpmDevice {
    /// The opened device node.
    file: File,
    /// Public area and name per raw handle value, filled by `read_public` and pruned by
    /// `flush_context`.
    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
    /// Callback polled while waiting for a response; returning `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
    /// Maximum time to wait for a complete response to one command.
    timeout: Duration,
    /// Reusable buffer holding the marshalled command.
    command: Vec<u8>,
    /// Reusable buffer holding the bytes of the most recent response.
    response: Vec<u8>,
}
201
202impl std::fmt::Debug for TpmDevice {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 f.debug_struct("Device")
205 .field("file", &self.file)
206 .field("name_cache", &self.name_cache)
207 .field("timeout", &self.timeout)
208 .finish_non_exhaustive()
209 }
210}
211
212impl TpmDevice {
    /// Empty session list passed to `transmit` for commands sent without authorization
    /// sessions.
    const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
214
215 #[must_use]
217 pub fn builder() -> TpmDeviceBuilder {
218 TpmDeviceBuilder::default()
219 }
220
221 fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
222 let fd = self.file.as_fd();
223 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
224
225 let num_events = match poll(&mut fds, 100u16) {
226 Ok(num) => num,
227 Err(nix::Error::EINTR) => return Ok(0),
228 Err(e) => return Err(e.into()),
229 };
230
231 if num_events == 0 {
232 return Ok(0);
233 }
234
235 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
236
237 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
238 return Err(TpmDeviceError::UnexpectedEof);
239 }
240
241 if revents.contains(PollFlags::POLLIN) {
242 match self.file.read(buf) {
243 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
244 Ok(n) => Ok(n),
245 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
246 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
247 Err(e) => Err(e.into()),
248 }
249 } else if revents.contains(PollFlags::POLLHUP) {
250 Err(TpmDeviceError::UnexpectedEof)
251 } else {
252 Ok(0)
253 }
254 }
255
256 pub fn transmit<C: TpmFrame>(
276 &mut self,
277 command: &C,
278 sessions: &[TpmsAuthCommand],
279 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
280 self.prepare_command(command, sessions)?;
281 let cc = command.cc();
282
283 self.file.write_all(&self.command)?;
284 self.file.flush()?;
285
286 let start_time = Instant::now();
287 self.response.clear();
288 let mut total_size: Option<usize> = None;
289 let mut temp_buf = [0u8; 1024];
290
291 loop {
292 if (self.interrupted)() {
293 return Err(TpmDeviceError::Interrupted);
294 }
295 if start_time.elapsed() > self.timeout {
296 return Err(TpmDeviceError::Timeout);
297 }
298
299 let n = self.receive(&mut temp_buf)?;
300 if n > 0 {
301 self.response.extend_from_slice(&temp_buf[..n]);
302 }
303
304 if total_size.is_none() && self.response.len() >= 10 {
305 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
306 return Err(TpmDeviceError::InvalidResponse);
307 };
308 let size = u32::from_be_bytes(size_bytes) as usize;
309 if !(10..={ TPM_MAX_COMMAND_SIZE }).contains(&size) {
310 return Err(TpmDeviceError::InvalidResponse);
311 }
312 total_size = Some(size);
313 }
314
315 if let Some(size) = total_size {
316 if self.response.len() == size {
317 break;
318 }
319 if self.response.len() > size {
320 return Err(TpmDeviceError::InvalidResponse);
321 }
322 }
323 }
324
325 let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
326 trace!("{} R: {}", cc, hex::encode(&self.response));
327 Ok(result??)
328 }
329
330 fn prepare_command<C: TpmFrame>(
331 &mut self,
332 command: &C,
333 sessions: &[TpmsAuthCommand],
334 ) -> Result<(), TpmDeviceError> {
335 let cc = command.cc();
336 let tag = if sessions.is_empty() {
337 TpmSt::NoSessions
338 } else {
339 TpmSt::Sessions
340 };
341
342 self.command.resize(TPM_MAX_COMMAND_SIZE, 0);
343
344 let len = {
345 let mut writer = TpmWriter::new(&mut self.command);
346 tpm_marshal_command(command, tag, sessions, &mut writer)
347 .map_err(TpmDeviceError::Marshal)?;
348 writer.len()
349 };
350 self.command.truncate(len);
351
352 trace!("{} C: {}", cc, hex::encode(&self.command));
353 Ok(())
354 }
355
356 fn get_capability<T, F, N>(
365 &mut self,
366 cap: TpmCap,
367 property_start: u32,
368 count: u32,
369 mut extract: F,
370 next_prop: N,
371 ) -> Result<Vec<T>, TpmDeviceError>
372 where
373 T: Copy,
374 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
375 N: Fn(&T) -> u32,
376 {
377 let mut results = Vec::new();
378 let mut prop = property_start;
379 loop {
380 let (more_data, cap_data) =
381 self.get_capability_page(cap, TpmUint32(prop), TpmUint32(count))?;
382 let items: &[T] = extract(&cap_data.data)?;
383 results.extend_from_slice(items);
384
385 if more_data {
386 if let Some(last) = items.last() {
387 prop = next_prop(last);
388 } else {
389 break;
390 }
391 } else {
392 break;
393 }
394 }
395 Ok(results)
396 }
397
398 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
409 self.get_capability(
410 TpmCap::Algs,
411 0,
412 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
413 |caps| match caps {
414 TpmuCapabilities::Algs(algs) => Ok(algs),
415 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
416 },
417 |last| last.alg as u32 + 1,
418 )
419 }
420
421 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
432 self.get_capability(
433 TpmCap::Handles,
434 (class as u32) << 24,
435 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
436 |caps| match caps {
437 TpmuCapabilities::Handles(handles) => Ok(handles),
438 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
439 },
440 |last| last.value() + 1,
441 )
442 .map(|handles| handles.into_iter().collect())
443 }
444
445 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
456 self.get_capability(
457 TpmCap::EccCurves,
458 0,
459 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
460 |caps| match caps {
461 TpmuCapabilities::EccCurves(curves) => Ok(curves),
462 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
463 },
464 |last| *last as u32 + 1,
465 )
466 }
467
468 pub fn fetch_pcr_bank_list(&mut self) -> Result<(usize, Vec<TpmAlgId>), TpmDeviceError> {
483 let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
484 TpmCap::Pcrs,
485 0,
486 u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
487 |caps| match caps {
488 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
489 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
490 },
491 |last| last.hash as u32 + 1,
492 )?;
493
494 if pcrs.is_empty() {
495 return Err(TpmDeviceError::PcrBanksNotAvailable);
496 }
497
498 let mut count = 0;
499 let mut algs = Vec::with_capacity(pcrs.len());
500
501 for bank in pcrs {
502 let next_count = bank.pcr_select.len();
503 if count == 0 {
504 count = next_count;
505 }
506 if next_count != count {
507 return Err(TpmDeviceError::PcrBankSizeMismatch);
508 }
509 algs.push(bank.hash);
510 }
511
512 algs.sort();
513 Ok((count, algs))
514 }
515
516 fn get_capability_page(
526 &mut self,
527 cap: TpmCap,
528 property: TpmUint32,
529 property_count: TpmUint32,
530 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
531 let cmd = TpmGetCapabilityCommand {
532 cap,
533 property,
534 property_count,
535 handles: [],
536 };
537
538 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
539 let TpmGetCapabilityResponse {
540 more_data,
541 capability_data,
542 handles: [],
543 } = resp
544 .GetCapability()
545 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
546
547 Ok((more_data.into(), capability_data))
548 }
549
550 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<TpmUint32, TpmDeviceError> {
559 let (_, cap_data) = self.get_capability_page(
560 TpmCap::TpmProperties,
561 TpmUint32(property as u32),
562 TpmUint32(1),
563 )?;
564
565 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
566 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
567 };
568
569 let Some(prop) = props.iter().find(|prop| prop.property == property) else {
570 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
571 };
572
573 Ok(prop.value)
574 }
575
576 pub fn read_public(
585 &mut self,
586 handle: TpmHandle,
587 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
588 if let Some(cached) = self.name_cache.get(&handle.0) {
589 return Ok(cached.clone());
590 }
591
592 let cmd = TpmReadPublicCommand { handles: [handle] };
593 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
594
595 let read_public_resp = resp
596 .ReadPublic()
597 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
598
599 let public = read_public_resp.out_public.inner;
600 let name = read_public_resp.name;
601
602 self.name_cache.insert(handle.0, (public.clone(), name));
603 Ok((public, name))
604 }
605
606 pub fn find_persistent(
618 &mut self,
619 target_name: &Tpm2bName,
620 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
621 for handle in self.fetch_handles(TpmHt::Persistent)? {
622 match self.read_public(handle) {
623 Ok((_, name)) => {
624 if name == *target_name {
625 return Ok(Some(handle));
626 }
627 }
628 Err(TpmDeviceError::TpmRc(rc)) => {
629 let base = rc.base();
630 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
631 continue;
632 }
633 return Err(TpmDeviceError::TpmRc(rc));
634 }
635 Err(e) => return Err(e),
636 }
637 }
638 Ok(None)
639 }
640
641 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
650 let cmd = TpmContextSaveCommand {
651 handles: [save_handle],
652 };
653 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
654 let save_resp = resp
655 .ContextSave()
656 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
657 Ok(save_resp.context)
658 }
659
660 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
669 let cmd = TpmContextLoadCommand {
670 context,
671 handles: [],
672 };
673 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
674 let resp_inner = resp
675 .ContextLoad()
676 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
677 Ok(resp_inner.handles[0])
678 }
679
680 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
688 self.name_cache.remove(&handle.0);
689 let cmd = TpmFlushContextCommand {
690 flush_handle: handle,
691 handles: [],
692 };
693 self.transmit(&cmd, Self::NO_SESSIONS)?;
694 Ok(())
695 }
696
697 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
709 match self.load_context(context) {
710 Ok(handle) => self.flush_context(handle),
711 Err(TpmDeviceError::TpmRc(rc)) => {
712 let base = rc.base();
713 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
714 Ok(())
715 } else {
716 Err(TpmDeviceError::TpmRc(rc))
717 }
718 }
719 Err(e) => Err(e),
720 }
721 }
722
723 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
735 match self.load_context(context) {
736 Ok(handle) => match self.flush_context(handle) {
737 Ok(()) => Ok(true),
738 Err(e) => Err(e),
739 },
740 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
741 Err(e) => Err(e),
742 }
743 }
744}