1#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9 fcntl,
10 poll::{poll, PollFd, PollFlags},
11};
12use std::{
13 cell::RefCell,
14 collections::HashMap,
15 fs::{File, OpenOptions},
16 io::{Read, Write},
17 num::TryFromIntError,
18 os::fd::{AsFd, AsRawFd},
19 path::{Path, PathBuf},
20 rc::Rc,
21 time::{Duration, Instant},
22};
23
24use thiserror::Error;
25use tpm2_protocol::{
26 constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
27 data::{
28 Tpm2bName, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
29 TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
30 TpmtPublic, TpmuCapabilities,
31 },
32 frame::{
33 tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
34 TpmContextSaveCommand, TpmEvictControlCommand, TpmFlushContextCommand, TpmFrame,
35 TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
36 },
37 TpmHandle, TpmWriter,
38};
39use tracing::trace;
40
/// Errors reported by the TPM device transport in this module.
#[derive(Debug, Error)]
pub enum TpmDeviceError {
    /// The shared device cell is already mutably borrowed (see `with_device`).
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// A capability response did not contain the expected capability data,
    /// either because it carried a different variant or an empty list.
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
    /// The `interrupted` callback returned `true` while waiting for a response.
    #[error("operation interrupted by user")]
    Interrupted,
    /// The response header declared an out-of-range size, or more bytes were
    /// received than the header declared.
    #[error("invalid response")]
    InvalidResponse,
    /// No device was supplied to `with_device`.
    #[error("device not available")]
    NotAvailable,

    /// The command could not be marshalled into the command buffer.
    #[error("marshal: {0}")]
    Marshal(tpm2_protocol::TpmProtocolError),

    /// The response bytes could not be unmarshalled.
    #[error("unmarshal: {0}")]
    Unmarshal(tpm2_protocol::TpmProtocolError),

    /// The response body is not the one expected for the given command code.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),
    /// No complete response arrived within the configured timeout.
    #[error("TPM command timed out")]
    Timeout,
    /// The device reported end-of-file, hang-up or a poll error while a
    /// response was being read.
    #[error("unexpected EOF")]
    UnexpectedEof,
    /// An integer conversion failed.
    #[error("int decode: {0}")]
    IntDecode(#[from] TryFromIntError),
    /// An I/O error from opening, writing or reading the device file.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),
    /// A system call made through `nix` failed.
    #[error("syscall: {0}")]
    Nix(#[from] nix::Error),
    /// The TPM answered with a non-success return code.
    #[error("TPM return code: {0}")]
    TpmRc(TpmRc),
}
78
79impl From<TpmRc> for TpmDeviceError {
80 fn from(rc: TpmRc) -> Self {
81 Self::TpmRc(rc)
82 }
83}
84
85pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
97where
98 F: FnOnce(&mut TpmDevice) -> Result<T, E>,
99 E: From<TpmDeviceError>,
100{
101 let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
102 let mut device_guard = device_rc
103 .try_borrow_mut()
104 .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
105 f(&mut device_guard)
106}
107
/// Collects the settings needed to open a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    // Path of the device node that `build` opens.
    path: PathBuf,
    // Upper bound on how long `TpmDevice::transmit` waits for a response.
    timeout: Duration,
    // Callback polled while waiting for a response; returning `true` aborts
    // the wait.
    interrupted: Box<dyn Fn() -> bool>,
}
114
115impl Default for TpmDeviceBuilder {
116 fn default() -> Self {
117 Self {
118 path: PathBuf::from("/dev/tpmrm0"),
119 timeout: Duration::from_secs(120),
120 interrupted: Box::new(|| false),
121 }
122 }
123}
124
125impl TpmDeviceBuilder {
126 #[must_use]
128 pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
129 self.path = path.as_ref().to_path_buf();
130 self
131 }
132
133 #[must_use]
135 pub fn with_timeout(mut self, timeout: Duration) -> Self {
136 self.timeout = timeout;
137 self
138 }
139
140 #[must_use]
142 pub fn with_interrupted<F>(mut self, handler: F) -> Self
143 where
144 F: Fn() -> bool + 'static,
145 {
146 self.interrupted = Box::new(handler);
147 self
148 }
149
150 pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
158 let file = OpenOptions::new()
159 .read(true)
160 .write(true)
161 .open(&self.path)
162 .map_err(TpmDeviceError::Io)?;
163
164 let fd = file.as_raw_fd();
165 let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
166 let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
167 oflags.insert(fcntl::OFlag::O_NONBLOCK);
168 fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
169
170 Ok(TpmDevice {
171 file,
172 name_cache: HashMap::new(),
173 interrupted: self.interrupted,
174 timeout: self.timeout,
175 command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE as usize),
176 response: vec![0; TPM_MAX_COMMAND_SIZE as usize],
177 })
178 }
179}
180
/// An open TPM device together with its reusable I/O buffers.
pub struct TpmDevice {
    // Device node, switched to non-blocking mode by the builder.
    file: File,
    // Results of `read_public`, keyed by the raw handle value.
    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
    // Polled while waiting for a response; `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
    // Maximum time `transmit` waits for a complete response.
    timeout: Duration,
    // Scratch buffer holding the marshalled command.
    command: Vec<u8>,
    // Fixed-size buffer the response is read into.
    response: Vec<u8>,
}
189
190impl std::fmt::Debug for TpmDevice {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 f.debug_struct("Device")
193 .field("file", &self.file)
194 .field("name_cache", &self.name_cache)
195 .field("timeout", &self.timeout)
196 .finish_non_exhaustive()
197 }
198}
199
200impl TpmDevice {
    /// Empty session list, passed to `transmit` by commands that carry no
    /// authorization sessions.
    const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
202
    /// Returns a [`TpmDeviceBuilder`] initialised with its default settings.
    #[must_use]
    pub fn builder() -> TpmDeviceBuilder {
        TpmDeviceBuilder::default()
    }
208
209 fn receive(file: &mut File, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
210 let fd = file.as_fd();
211 let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
212
213 let num_events = match poll(&mut fds, 100u16) {
214 Ok(num) => num,
215 Err(nix::Error::EINTR) => return Ok(0),
216 Err(e) => return Err(e.into()),
217 };
218
219 if num_events == 0 {
220 return Ok(0);
221 }
222
223 let revents = fds[0].revents().unwrap_or(PollFlags::empty());
224
225 if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
226 return Err(TpmDeviceError::UnexpectedEof);
227 }
228
229 if revents.contains(PollFlags::POLLIN) {
230 match file.read(buf) {
231 Ok(0) => Err(TpmDeviceError::UnexpectedEof),
232 Ok(n) => Ok(n),
233 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
234 Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
235 Err(e) => Err(e.into()),
236 }
237 } else if revents.contains(PollFlags::POLLHUP) {
238 Err(TpmDeviceError::UnexpectedEof)
239 } else {
240 Ok(0)
241 }
242 }
243
244 pub fn transmit<C: TpmFrame>(
265 &mut self,
266 command: &C,
267 sessions: &[TpmsAuthCommand],
268 ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
269 self.prepare_command(command, sessions)?;
270 let cc = command.cc();
271
272 self.file.write_all(&self.command)?;
273 self.file.flush()?;
274
275 let start_time = Instant::now();
276 let mut read_offset = 0;
277 let mut total_size: Option<usize> = None;
278
279 loop {
280 if (self.interrupted)() {
281 return Err(TpmDeviceError::Interrupted);
282 }
283 if start_time.elapsed() > self.timeout {
284 return Err(TpmDeviceError::Timeout);
285 }
286
287 let n = Self::receive(&mut self.file, &mut self.response[read_offset..])?;
288 if n > 0 {
289 read_offset += n;
290 }
291
292 if total_size.is_none() && read_offset >= 10 {
293 let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
294 return Err(TpmDeviceError::InvalidResponse);
295 };
296 let size = u32::from_be_bytes(size_bytes) as usize;
297 if !(10..=TPM_MAX_COMMAND_SIZE as usize).contains(&size) {
298 return Err(TpmDeviceError::InvalidResponse);
299 }
300 total_size = Some(size);
301 }
302
303 if let Some(size) = total_size {
304 if read_offset == size {
305 break;
306 }
307 if read_offset > size {
308 return Err(TpmDeviceError::InvalidResponse);
309 }
310 }
311 }
312
313 let response_data = &self.response[..read_offset];
314 let result = tpm_unmarshal_response(cc, response_data).map_err(TpmDeviceError::Unmarshal);
315 trace!("{} R: {}", cc, hex::encode(response_data));
316 Ok(result??)
317 }
318
319 fn prepare_command<C: TpmFrame>(
320 &mut self,
321 command: &C,
322 sessions: &[TpmsAuthCommand],
323 ) -> Result<(), TpmDeviceError> {
324 let cc = command.cc();
325 let tag = if sessions.is_empty() {
326 TpmSt::NoSessions
327 } else {
328 TpmSt::Sessions
329 };
330
331 self.command.clear();
332 let mut writer = TpmWriter::new(&mut self.command);
333 tpm_marshal_command(command, tag, sessions, &mut writer)
334 .map_err(TpmDeviceError::Marshal)?;
335
336 trace!("{} C: {}", cc, hex::encode(&self.command));
337 Ok(())
338 }
339
340 fn get_capability<T, F, N>(
349 &mut self,
350 cap: TpmCap,
351 property_start: u32,
352 count: u32,
353 mut extract: F,
354 next_prop: N,
355 ) -> Result<Vec<T>, TpmDeviceError>
356 where
357 T: Copy,
358 F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
359 N: Fn(&T) -> u32,
360 {
361 let mut results = Vec::new();
362 let mut prop = property_start;
363 loop {
364 let (more_data, cap_data) = self.get_capability_page(cap, prop, count)?;
365 let items: &[T] = extract(&cap_data.data)?;
366 results.extend_from_slice(items);
367
368 if more_data {
369 if let Some(last) = items.last() {
370 prop = next_prop(last);
371 } else {
372 break;
373 }
374 } else {
375 break;
376 }
377 }
378 Ok(results)
379 }
380
381 pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
392 self.get_capability(
393 TpmCap::Algs,
394 0,
395 u32::try_from(MAX_HANDLES)?,
396 |caps| match caps {
397 TpmuCapabilities::Algs(algs) => Ok(algs),
398 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
399 },
400 |last| last.alg as u32 + 1,
401 )
402 }
403
404 pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
415 self.get_capability(
416 TpmCap::Handles,
417 (class as u32) << 24,
418 u32::try_from(MAX_HANDLES)?,
419 |caps| match caps {
420 TpmuCapabilities::Handles(handles) => Ok(handles),
421 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
422 },
423 |last| *last + 1,
424 )
425 .map(|handles| handles.into_iter().map(TpmHandle).collect())
426 }
427
428 pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
439 self.get_capability(
440 TpmCap::EccCurves,
441 0,
442 u32::try_from(MAX_HANDLES)?,
443 |caps| match caps {
444 TpmuCapabilities::EccCurves(curves) => Ok(curves),
445 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
446 },
447 |last| *last as u32 + 1,
448 )
449 }
450
451 pub fn fetch_pcr_banks(&mut self) -> Result<Vec<TpmsPcrSelection>, TpmDeviceError> {
462 self.get_capability(
463 TpmCap::Pcrs,
464 0,
465 u32::try_from(MAX_HANDLES)?,
466 |caps| match caps {
467 TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
468 _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
469 },
470 |last| last.hash as u32 + 1,
471 )
472 }
473
474 fn get_capability_page(
484 &mut self,
485 cap: TpmCap,
486 property: u32,
487 count: u32,
488 ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
489 let cmd = TpmGetCapabilityCommand {
490 cap,
491 property,
492 property_count: count,
493 handles: [],
494 };
495
496 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
497 let TpmGetCapabilityResponse {
498 more_data,
499 capability_data,
500 handles: [],
501 } = resp
502 .GetCapability()
503 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
504
505 Ok((more_data.into(), capability_data))
506 }
507
508 pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<u32, TpmDeviceError> {
517 let (_, cap_data) = self.get_capability_page(TpmCap::TpmProperties, property as u32, 1)?;
518
519 let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
520 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
521 };
522
523 let Some(prop) = props.first() else {
524 return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
525 };
526
527 Ok(prop.value)
528 }
529
530 pub fn read_public(
539 &mut self,
540 handle: TpmHandle,
541 ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
542 if let Some(cached) = self.name_cache.get(&handle.0) {
543 return Ok(cached.clone());
544 }
545
546 let cmd = TpmReadPublicCommand { handles: [handle] };
547 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
548
549 let read_public_resp = resp
550 .ReadPublic()
551 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
552
553 let public = read_public_resp.out_public.inner;
554 let name = read_public_resp.name;
555
556 self.name_cache.insert(handle.0, (public.clone(), name));
557 Ok((public, name))
558 }
559
560 pub fn find_persistent(
572 &mut self,
573 target_name: &Tpm2bName,
574 ) -> Result<Option<TpmHandle>, TpmDeviceError> {
575 for handle in self.fetch_handles(TpmHt::Persistent)? {
576 match self.read_public(handle) {
577 Ok((_, name)) => {
578 if name == *target_name {
579 return Ok(Some(handle));
580 }
581 }
582 Err(TpmDeviceError::TpmRc(rc)) => {
583 let base = rc.base();
584 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
585 continue;
586 }
587 return Err(TpmDeviceError::TpmRc(rc));
588 }
589 Err(e) => return Err(e),
590 }
591 }
592 Ok(None)
593 }
594
595 pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
604 let cmd = TpmContextSaveCommand {
605 handles: [save_handle],
606 };
607 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
608 let save_resp = resp
609 .ContextSave()
610 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
611 Ok(save_resp.context)
612 }
613
614 pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
623 let cmd = TpmContextLoadCommand {
624 context,
625 handles: [],
626 };
627 let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
628 let resp_inner = resp
629 .ContextLoad()
630 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
631 Ok(resp_inner.handles[0])
632 }
633
634 pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
642 self.name_cache.remove(&handle.0);
643 let cmd = TpmFlushContextCommand {
644 flush_handle: handle,
645 handles: [],
646 };
647 self.transmit(&cmd, Self::NO_SESSIONS)?;
648 Ok(())
649 }
650
651 pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
663 match self.load_context(context) {
664 Ok(handle) => self.flush_context(handle),
665 Err(TpmDeviceError::TpmRc(rc)) => {
666 let base = rc.base();
667 if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
668 Ok(())
669 } else {
670 Err(TpmDeviceError::TpmRc(rc))
671 }
672 }
673 Err(e) => Err(e),
674 }
675 }
676
677 pub fn evict_control(
686 &mut self,
687 auth: TpmHandle,
688 object_handle: TpmHandle,
689 persistent_handle: TpmHandle,
690 sessions: &[TpmsAuthCommand],
691 ) -> Result<(), TpmDeviceError> {
692 let cmd = TpmEvictControlCommand {
693 handles: [auth, object_handle],
694 persistent_handle,
695 };
696
697 let (resp, _) = self.transmit(&cmd, sessions)?;
698
699 resp.EvictControl()
700 .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::EvictControl))?;
701
702 Ok(())
703 }
704
705 pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
717 match self.load_context(context) {
718 Ok(handle) => match self.flush_context(handle) {
719 Ok(()) => Ok(true),
720 Err(e) => Err(e),
721 },
722 Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
723 Err(e) => Err(e),
724 }
725 }
726}