tpm2_device/
lib.rs

1// SPDX-License-Identifier: GPL-3-0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9    fcntl,
10    poll::{poll, PollFd, PollFlags},
11};
12use rand::{thread_rng, RngCore};
13use std::{
14    cell::RefCell,
15    collections::HashMap,
16    fs::{File, OpenOptions},
17    io::{Read, Write},
18    os::fd::{AsFd, AsRawFd},
19    path::{Path, PathBuf},
20    rc::Rc,
21    time::{Duration, Instant},
22};
23
24use thiserror::Error;
25use tpm2_crypto::TpmHash;
26use tpm2_protocol::{
27    basic::{TpmHandle, TpmUint32},
28    constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
29    data::{
30        Tpm2bEncryptedSecret, Tpm2bName, Tpm2bNonce, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt,
31        TpmPt, TpmRc, TpmRcBase, TpmRh, TpmSe, TpmSt, TpmaSession, TpmsAlgProperty,
32        TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelect, TpmsPcrSelection,
33        TpmtPublic, TpmtSymDefObject, TpmuCapabilities,
34    },
35    frame::{
36        tpm_marshal_command, tpm_unmarshal_response, TpmAuthCommands, TpmAuthResponses, TpmCommand,
37        TpmContextLoadCommand, TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame,
38        TpmGetCapabilityCommand, TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
39        TpmStartAuthSessionCommand,
40    },
41    TpmWriter,
42};
43use tracing::{debug, trace};
44
45/// Errors that can occur when talking to a TPM device.
46#[derive(Debug, Error)]
47pub enum TpmDeviceError {
48    #[error("device is already borrowed")]
49    AlreadyBorrowed,
50    #[error("capability not found: {0}")]
51    CapabilityMissing(TpmCap),
52    #[error("operation interrupted by user")]
53    Interrupted,
54    #[error("invalid response")]
55    InvalidResponse,
56
57    #[error("I/O: {0}")]
58    Io(#[from] std::io::Error),
59
60    #[error("malformed data")]
61    MalformedData,
62
63    /// Marshaling a TPM protocol encoded object failed.
64    #[error("marshal: {0}")]
65    Marshal(tpm2_protocol::TpmProtocolError),
66
67    #[error("device not available")]
68    NotAvailable,
69    #[error("operation failed")]
70    OperationFailed,
71    #[error("out of memory")]
72    OutOfMemory,
73    #[error("PCR banks not available")]
74    PcrBanksNotAvailable,
75    #[error("PCR bank selection mismatch")]
76    PcrBankSelectionMismatch,
77
78    /// The TPM response did not match the expected command code.
79    #[error("response mismatch: {0}")]
80    ResponseMismatch(TpmCc),
81
82    #[error("TPM command timed out")]
83    Timeout,
84    #[error("TPM return code: {0}")]
85    TpmRc(TpmRc),
86
87    /// Unmarshaling a TPM protocol encoded object failed.
88    #[error("unmarshal: {0}")]
89    Unmarshal(tpm2_protocol::TpmProtocolError),
90
91    #[error("unexpected EOF")]
92    UnexpectedEof,
93}
94
95impl From<TpmRc> for TpmDeviceError {
96    fn from(rc: TpmRc) -> Self {
97        Self::TpmRc(rc)
98    }
99}
100
101impl From<nix::Error> for TpmDeviceError {
102    fn from(err: nix::Error) -> Self {
103        Self::Io(std::io::Error::from_raw_os_error(err as i32))
104    }
105}
106
107/// Executes a closure with a mutable reference to a `TpmDevice`.
108///
109/// This helper function centralizes the boilerplate for safely acquiring a
110/// mutable borrow of a `TpmDevice` from the shared `Rc<RefCell<...>>`.
111///
112/// # Errors
113///
114/// Returns [`NotAvailable`](crate::TpmDeviceError::NotAvailable) when no device
115/// is present and [`AlreadyBorrowed`](crate::TpmDeviceError::AlreadyBorrowed)
116/// when the device is already mutably borrowed, both converted into the caller's
117/// error type `E`. Propagates any error returned by the closure `f`.
118pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
119where
120    F: FnOnce(&mut TpmDevice) -> Result<T, E>,
121    E: From<TpmDeviceError>,
122{
123    let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
124    let mut device_guard = device_rc
125        .try_borrow_mut()
126        .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
127    f(&mut device_guard)
128}
129
/// Builder that configures and opens a [`TpmDevice`].
pub struct TpmDeviceBuilder {
    /// Device node opened by `build`.
    path: PathBuf,
    /// Upper bound on the time spent waiting for a single TPM response.
    timeout: Duration,
    /// Checked while waiting for a response; returning `true` aborts the wait.
    interrupted: Box<dyn Fn() -> bool>,
}
136
137impl Default for TpmDeviceBuilder {
138    fn default() -> Self {
139        Self {
140            path: PathBuf::from("/dev/tpmrm0"),
141            timeout: Duration::from_secs(120),
142            interrupted: Box::new(|| false),
143        }
144    }
145}
146
147impl TpmDeviceBuilder {
148    /// Sets the device file path.
149    #[must_use]
150    pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
151        self.path = path.as_ref().to_path_buf();
152        self
153    }
154
155    /// Sets the operation timeout.
156    #[must_use]
157    pub fn with_timeout(mut self, timeout: Duration) -> Self {
158        self.timeout = timeout;
159        self
160    }
161
162    /// Sets the interruption check callback.
163    #[must_use]
164    pub fn with_interrupted<F>(mut self, handler: F) -> Self
165    where
166        F: Fn() -> bool + 'static,
167    {
168        self.interrupted = Box::new(handler);
169        self
170    }
171
172    /// Opens the TPM device file and constructs the `TpmDevice`.
173    ///
174    /// # Errors
175    ///
176    /// Returns [`Io`](crate::TpmDeviceError::Io) when the device file cannot be
177    /// opened or when configuring the file descriptor flags fails.
178    pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
179        let file = OpenOptions::new()
180            .read(true)
181            .write(true)
182            .open(&self.path)
183            .map_err(TpmDeviceError::Io)?;
184
185        let fd = file.as_raw_fd();
186        let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
187        let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
188        oflags.insert(fcntl::OFlag::O_NONBLOCK);
189        fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
190
191        Ok(TpmDevice {
192            file,
193            name_cache: HashMap::new(),
194            interrupted: self.interrupted,
195            timeout: self.timeout,
196            command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
197            response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
198        })
199    }
200}
201
202pub struct TpmDevice {
203    file: File,
204    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
205    interrupted: Box<dyn Fn() -> bool>,
206    timeout: Duration,
207    command: Vec<u8>,
208    response: Vec<u8>,
209}
210
211impl std::fmt::Debug for TpmDevice {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("Device")
214            .field("file", &self.file)
215            .field("name_cache", &self.name_cache)
216            .field("timeout", &self.timeout)
217            .finish_non_exhaustive()
218    }
219}
220
221impl TpmDevice {
222    const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
223
224    /// Creates a new builder for `TpmDevice`.
225    #[must_use]
226    pub fn builder() -> TpmDeviceBuilder {
227        TpmDeviceBuilder::default()
228    }
229
230    fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
231        let fd = self.file.as_fd();
232        let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
233
234        let num_events = match poll(&mut fds, 100u16) {
235            Ok(num) => num,
236            Err(nix::Error::EINTR) => return Ok(0),
237            Err(e) => return Err(e.into()),
238        };
239
240        if num_events == 0 {
241            return Ok(0);
242        }
243
244        let revents = fds[0].revents().unwrap_or(PollFlags::empty());
245
246        if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
247            return Err(TpmDeviceError::UnexpectedEof);
248        }
249
250        if revents.contains(PollFlags::POLLIN) {
251            match self.file.read(buf) {
252                Ok(0) => Err(TpmDeviceError::UnexpectedEof),
253                Ok(n) => Ok(n),
254                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
255                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
256                Err(e) => Err(e.into()),
257            }
258        } else if revents.contains(PollFlags::POLLHUP) {
259            Err(TpmDeviceError::UnexpectedEof)
260        } else {
261            Ok(0)
262        }
263    }
264
265    /// Performs the whole TPM command transmission process.
266    ///
267    /// # Errors
268    ///
269    /// Returns [`Interrupted`](crate::TpmDeviceError::Interrupted) when the
270    /// interrupt callback requests cancellation.
271    /// Returns [`Timeout`](crate::TpmDeviceError::Timeout) when the TPM does
272    /// not respond within the configured timeout.
273    /// Returns [`Io`](crate::TpmDeviceError::Io) when a write, flush, or read
274    /// operation on the device file fails, or when polling the device file
275    /// descriptor fails.
276    /// Returns [`InvalidResponse`](crate::TpmDeviceError::InvalidResponse) or
277    /// [`UnexpectedEof`](crate::TpmDeviceError::UnexpectedEof) when the TPM
278    /// reply is malformed, truncated, or longer than the announced size.
279    /// Returns [`Marshal`](crate::TpmDeviceError::Marshal) or
280    /// [`Unmarshal`](crate::TpmDeviceError::Unmarshal) when encoding the
281    /// command or decoding the response fails.
282    /// Returns [`TpmRc`](crate::TpmDeviceError::TpmRc) when the TPM returns an
283    /// error code.
284    pub fn transmit<C: TpmFrame>(
285        &mut self,
286        command: &C,
287        sessions: &[TpmsAuthCommand],
288    ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
289        self.prepare_command(command, sessions)?;
290        let cc = command.cc();
291
292        self.file.write_all(&self.command)?;
293        self.file.flush()?;
294
295        let start_time = Instant::now();
296        self.response.clear();
297        let mut total_size: Option<usize> = None;
298        let mut temp_buf = [0u8; 1024];
299
300        loop {
301            if (self.interrupted)() {
302                return Err(TpmDeviceError::Interrupted);
303            }
304            if start_time.elapsed() > self.timeout {
305                return Err(TpmDeviceError::Timeout);
306            }
307
308            let n = self.receive(&mut temp_buf)?;
309            if n > 0 {
310                self.response.extend_from_slice(&temp_buf[..n]);
311            }
312
313            if total_size.is_none() && self.response.len() >= 10 {
314                let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
315                    return Err(TpmDeviceError::InvalidResponse);
316                };
317                let size = u32::from_be_bytes(size_bytes) as usize;
318                if !(10..={ TPM_MAX_COMMAND_SIZE }).contains(&size) {
319                    return Err(TpmDeviceError::InvalidResponse);
320                }
321                total_size = Some(size);
322            }
323
324            if let Some(size) = total_size {
325                if self.response.len() == size {
326                    break;
327                }
328                if self.response.len() > size {
329                    return Err(TpmDeviceError::InvalidResponse);
330                }
331            }
332        }
333
334        let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
335        trace!("{} R: {}", cc, hex::encode(&self.response));
336        Ok(result??)
337    }
338
339    fn prepare_command<C: TpmFrame>(
340        &mut self,
341        command: &C,
342        sessions: &[TpmsAuthCommand],
343    ) -> Result<(), TpmDeviceError> {
344        let cc = command.cc();
345        let tag = if sessions.is_empty() {
346            TpmSt::NoSessions
347        } else {
348            TpmSt::Sessions
349        };
350
351        self.command.resize(TPM_MAX_COMMAND_SIZE, 0);
352
353        let len = {
354            let mut writer = TpmWriter::new(&mut self.command);
355            tpm_marshal_command(command, tag, sessions, &mut writer)
356                .map_err(TpmDeviceError::Marshal)?;
357            writer.len()
358        };
359        self.command.truncate(len);
360
361        trace!("{} C: {}", cc, hex::encode(&self.command));
362        Ok(())
363    }
364
365    /// Fetches a complete list of capabilities from the TPM, handling
366    /// pagination.
367    ///
368    /// # Errors
369    ///
370    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) returned by
371    /// [`get_capability_page`](TpmDevice::get_capability_page) or by the
372    /// `extract` closure.
373    fn get_capability<T, F, N>(
374        &mut self,
375        cap: TpmCap,
376        property_start: u32,
377        count: u32,
378        mut extract: F,
379        next_prop: N,
380    ) -> Result<Vec<T>, TpmDeviceError>
381    where
382        T: Copy,
383        F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
384        N: Fn(&T) -> u32,
385    {
386        let mut results = Vec::new();
387        let mut prop = property_start;
388        loop {
389            let (more_data, cap_data) =
390                self.get_capability_page(cap, TpmUint32(prop), TpmUint32(count))?;
391            let items: &[T] = extract(&cap_data.data)?;
392            results.extend_from_slice(items);
393
394            if more_data {
395                if let Some(last) = items.last() {
396                    prop = next_prop(last);
397                } else {
398                    break;
399                }
400            } else {
401                break;
402            }
403        }
404        Ok(results)
405    }
406
407    /// Retrieves all algorithm properties supported by the TPM.
408    ///
409    /// # Errors
410    ///
411    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
412    /// the handle count cannot be represented as `u32`. Propagates any
413    /// [`TpmDeviceError`](crate::TpmDeviceError) from
414    /// [`get_capability`](TpmDevice::get_capability), including
415    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
416    /// TPM does not report algorithm properties.
417    pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
418        self.get_capability(
419            TpmCap::Algs,
420            0,
421            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
422            |caps| match caps {
423                TpmuCapabilities::Algs(algs) => Ok(algs),
424                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
425            },
426            |last| last.alg as u32 + 1,
427        )
428    }
429
430    /// Retrieves all handles of a specific type from the TPM.
431    ///
432    /// # Errors
433    ///
434    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
435    /// the handle count cannot be represented as `u32`. Propagates any
436    /// [`TpmDeviceError`](crate::TpmDeviceError) from
437    /// [`get_capability`](TpmDevice::get_capability), including
438    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
439    /// TPM does not report handles of the requested class.
440    pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
441        self.get_capability(
442            TpmCap::Handles,
443            (class as u32) << 24,
444            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
445            |caps| match caps {
446                TpmuCapabilities::Handles(handles) => Ok(handles),
447                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
448            },
449            |last| last.value() + 1,
450        )
451        .map(|handles| handles.into_iter().collect())
452    }
453
454    /// Retrieves all available ECC curves supported by the TPM.
455    ///
456    /// # Errors
457    ///
458    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
459    /// the handle count cannot be represented as `u32`. Propagates any
460    /// [`TpmDeviceError`](crate::TpmDeviceError) from
461    /// [`get_capability`](TpmDevice::get_capability), including
462    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
463    /// TPM does not report ECC curves.
464    pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
465        self.get_capability(
466            TpmCap::EccCurves,
467            0,
468            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
469            |caps| match caps {
470                TpmuCapabilities::EccCurves(curves) => Ok(curves),
471                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
472            },
473            |last| *last as u32 + 1,
474        )
475    }
476
477    /// Retrieves the list of active PCR banks and the bank selection mask.
478    ///
479    /// # Errors
480    ///
481    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
482    /// the handle count cannot be represented as `u32`. Propagates any
483    /// [`TpmDeviceError`](crate::TpmDeviceError) from
484    /// [`get_capability`](TpmDevice::get_capability), including
485    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
486    /// TPM does not report PCRs.
487    /// Returns [`PcrBanksNotAvailable`](crate::TpmDeviceError::PcrBanksNotAvailable)
488    /// if the list of banks is empty or if no banks have allocated PCRs.
489    /// Returns [`PcrBankSelectionMismatch`](crate::TpmDeviceError::PcrBankSelectionMismatch)
490    /// if the PCR selection mask is not identical across all active banks.
491    pub fn fetch_pcr_bank_list(
492        &mut self,
493    ) -> Result<(Vec<TpmAlgId>, TpmsPcrSelect), TpmDeviceError> {
494        let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
495            TpmCap::Pcrs,
496            0,
497            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
498            |caps| match caps {
499                TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
500                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
501            },
502            |last| last.hash as u32 + 1,
503        )?;
504
505        if pcrs.is_empty() {
506            return Err(TpmDeviceError::PcrBanksNotAvailable);
507        }
508
509        let mut common_select: Option<TpmsPcrSelect> = None;
510        let mut algs = Vec::with_capacity(pcrs.len());
511
512        for bank in pcrs {
513            if bank.pcr_select.iter().all(|&b| b == 0) {
514                debug!(
515                    "skipping unallocated bank {:?} (mask: {})",
516                    bank.hash,
517                    hex::encode(&*bank.pcr_select)
518                );
519                continue;
520            }
521
522            if let Some(ref select) = common_select {
523                if bank.pcr_select != *select {
524                    return Err(TpmDeviceError::PcrBankSelectionMismatch);
525                }
526            } else {
527                common_select = Some(bank.pcr_select);
528            }
529            algs.push(bank.hash);
530        }
531
532        let select = common_select.ok_or(TpmDeviceError::PcrBanksNotAvailable)?;
533
534        algs.sort();
535        Ok((algs, select))
536    }
537
538    /// Fetches and returns one page of capabilities of a certain type from the
539    /// TPM.
540    ///
541    /// # Errors
542    ///
543    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
544    /// [`transmit`](TpmDevice::transmit). Returns
545    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
546    /// TPM response does not contain `TPM2_GetCapability` data.
547    fn get_capability_page(
548        &mut self,
549        cap: TpmCap,
550        property: TpmUint32,
551        property_count: TpmUint32,
552    ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
553        let cmd = TpmGetCapabilityCommand {
554            cap,
555            property,
556            property_count,
557            handles: [],
558        };
559
560        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
561        let TpmGetCapabilityResponse {
562            more_data,
563            capability_data,
564            handles: [],
565        } = resp
566            .GetCapability()
567            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
568
569        Ok((more_data.into(), capability_data))
570    }
571
572    /// Reads a specific TPM property.
573    ///
574    /// # Errors
575    ///
576    /// Returns [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing)
577    /// when the TPM does not report the requested property. Propagates any
578    /// [`TpmDeviceError`](crate::TpmDeviceError) from
579    /// [`get_capability_page`](TpmDevice::get_capability_page).
580    pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<TpmUint32, TpmDeviceError> {
581        let (_, cap_data) = self.get_capability_page(
582            TpmCap::TpmProperties,
583            TpmUint32(property as u32),
584            TpmUint32(1),
585        )?;
586
587        let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
588            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
589        };
590
591        let Some(prop) = props.iter().find(|prop| prop.property == property) else {
592            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
593        };
594
595        Ok(prop.value)
596    }
597
598    /// Reads the public area of a TPM object.
599    ///
600    /// # Errors
601    ///
602    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
603    /// [`transmit`](TpmDevice::transmit). Returns
604    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
605    /// TPM response does not contain `TPM2_ReadPublic` data.
606    pub fn read_public(
607        &mut self,
608        handle: TpmHandle,
609    ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
610        if let Some(cached) = self.name_cache.get(&handle.0) {
611            return Ok(cached.clone());
612        }
613
614        let cmd = TpmReadPublicCommand { handles: [handle] };
615        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
616
617        let read_public_resp = resp
618            .ReadPublic()
619            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
620
621        let public = read_public_resp.out_public.inner;
622        let name = read_public_resp.name;
623
624        self.name_cache.insert(handle.0, (public.clone(), name));
625        Ok((public, name))
626    }
627
628    /// Finds a persistent handle by its `Tpm2bName`.
629    ///
630    /// # Errors
631    ///
632    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
633    /// [`fetch_handles`](TpmDevice::fetch_handles) and
634    /// [`read_public`](TpmDevice::read_public), except for TPM reference and
635    /// handle errors with base
636    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
637    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
638    /// invalid handles and skipped.
639    pub fn find_persistent(
640        &mut self,
641        target_name: &Tpm2bName,
642    ) -> Result<Option<TpmHandle>, TpmDeviceError> {
643        for handle in self.fetch_handles(TpmHt::Persistent)? {
644            match self.read_public(handle) {
645                Ok((_, name)) => {
646                    if name == *target_name {
647                        return Ok(Some(handle));
648                    }
649                }
650                Err(TpmDeviceError::TpmRc(rc)) => {
651                    let base = rc.base();
652                    if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
653                        continue;
654                    }
655                    return Err(TpmDeviceError::TpmRc(rc));
656                }
657                Err(e) => return Err(e),
658            }
659        }
660        Ok(None)
661    }
662
663    /// Saves the context of a transient object or session.
664    ///
665    /// # Errors
666    ///
667    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
668    /// [`transmit`](TpmDevice::transmit). Returns
669    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
670    /// TPM response does not contain `TPM2_ContextSave` data.
671    pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
672        let cmd = TpmContextSaveCommand {
673            handles: [save_handle],
674        };
675        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
676        let save_resp = resp
677            .ContextSave()
678            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
679        Ok(save_resp.context)
680    }
681
682    /// Loads a TPM context and returns the handle.
683    ///
684    /// # Errors
685    ///
686    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
687    /// [`transmit`](TpmDevice::transmit). Returns
688    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
689    /// TPM response does not contain `TPM2_ContextLoad` data.
690    pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
691        let cmd = TpmContextLoadCommand {
692            context,
693            handles: [],
694        };
695        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
696        let resp_inner = resp
697            .ContextLoad()
698            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
699        Ok(resp_inner.handles[0])
700    }
701
702    /// Flushes a transient object or session from the TPM and removes it from
703    /// the cache.
704    ///
705    /// # Errors
706    ///
707    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
708    /// [`transmit`](TpmDevice::transmit).
709    pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
710        self.name_cache.remove(&handle.0);
711        let cmd = TpmFlushContextCommand {
712            flush_handle: handle,
713            handles: [],
714        };
715        self.transmit(&cmd, Self::NO_SESSIONS)?;
716        Ok(())
717    }
718
719    /// Loads a session context and then flushes the resulting handle.
720    ///
721    /// # Errors
722    ///
723    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
724    /// [`load_context`](TpmDevice::load_context) or
725    /// [`flush_context`](TpmDevice::flush_context) except for TPM reference
726    /// errors with base
727    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
728    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
729    /// a successful no-op.
730    pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
731        match self.load_context(context) {
732            Ok(handle) => self.flush_context(handle),
733            Err(TpmDeviceError::TpmRc(rc)) => {
734                let base = rc.base();
735                if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
736                    Ok(())
737                } else {
738                    Err(TpmDeviceError::TpmRc(rc))
739                }
740            }
741            Err(e) => Err(e),
742        }
743    }
744}
745
746/// A builder for creating a TPM policy session.
747pub struct TpmPolicySessionBuilder {
748    bind: TpmHandle,
749    tpm_key: TpmHandle,
750    nonce_caller: Option<Tpm2bNonce>,
751    encrypted_salt: Option<Tpm2bEncryptedSecret>,
752    session_type: TpmSe,
753    symmetric: TpmtSymDefObject,
754    auth_hash: TpmAlgId,
755}
756
757impl Default for TpmPolicySessionBuilder {
758    fn default() -> Self {
759        Self {
760            bind: (TpmRh::Null as u32).into(),
761            tpm_key: (TpmRh::Null as u32).into(),
762            nonce_caller: None,
763            encrypted_salt: None,
764            session_type: TpmSe::Policy,
765            symmetric: TpmtSymDefObject::default(),
766            auth_hash: TpmAlgId::Sha256,
767        }
768    }
769}
770
771impl TpmPolicySessionBuilder {
772    #[must_use]
773    pub fn new() -> Self {
774        Self::default()
775    }
776
777    #[must_use]
778    pub fn with_bind(mut self, bind: TpmHandle) -> Self {
779        self.bind = bind;
780        self
781    }
782
783    #[must_use]
784    pub fn with_tpm_key(mut self, tpm_key: TpmHandle) -> Self {
785        self.tpm_key = tpm_key;
786        self
787    }
788
789    #[must_use]
790    pub fn with_nonce_caller(mut self, nonce: Tpm2bNonce) -> Self {
791        self.nonce_caller = Some(nonce);
792        self
793    }
794
795    #[must_use]
796    pub fn with_encrypted_salt(mut self, salt: Tpm2bEncryptedSecret) -> Self {
797        self.encrypted_salt = Some(salt);
798        self
799    }
800
801    #[must_use]
802    pub fn with_session_type(mut self, session_type: TpmSe) -> Self {
803        self.session_type = session_type;
804        self
805    }
806
807    #[must_use]
808    pub fn with_symmetric(mut self, symmetric: TpmtSymDefObject) -> Self {
809        self.symmetric = symmetric;
810        self
811    }
812
813    #[must_use]
814    pub fn with_auth_hash(mut self, auth_hash: TpmAlgId) -> Self {
815        self.auth_hash = auth_hash;
816        self
817    }
818
819    /// Opens the policy session on the provided device.
820    ///
821    /// # Errors
822    ///
823    /// Returns [`OutOfMemory`](crate::TpmDeviceError::OutOfMemory) if nonce generation fails.
824    /// Returns [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) if the TPM response is unexpected.
825    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from transmission.
826    pub fn open(self, device: &mut TpmDevice) -> Result<TpmPolicySession, TpmDeviceError> {
827        let nonce_caller = if let Some(nonce) = self.nonce_caller {
828            nonce
829        } else {
830            let digest_len = TpmHash::from(self.auth_hash).size();
831            let mut nonce_bytes = vec![0; digest_len];
832            thread_rng().fill_bytes(&mut nonce_bytes);
833            Tpm2bNonce::try_from(nonce_bytes.as_slice()).map_err(|_| TpmDeviceError::OutOfMemory)?
834        };
835
836        let cmd = TpmStartAuthSessionCommand {
837            nonce_caller,
838            encrypted_salt: self.encrypted_salt.unwrap_or_default(),
839            session_type: self.session_type,
840            symmetric: self.symmetric,
841            auth_hash: self.auth_hash,
842            handles: [self.tpm_key, self.bind],
843        };
844
845        let (resp, _) = device.transmit(&cmd, TpmDevice::NO_SESSIONS)?;
846        let start_resp = resp
847            .StartAuthSession()
848            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::StartAuthSession))?;
849
850        Ok(TpmPolicySession {
851            handle: start_resp.handles[0],
852            attributes: TpmaSession::CONTINUE_SESSION,
853            hash_alg: self.auth_hash,
854            nonce_tpm: start_resp.nonce_tpm,
855        })
856    }
857}
858
859/// Represents an active TPM policy session.
860#[derive(Debug, Clone)]
861pub struct TpmPolicySession {
862    handle: TpmHandle,
863    attributes: TpmaSession,
864    hash_alg: TpmAlgId,
865    nonce_tpm: Tpm2bNonce,
866}
867
868impl TpmPolicySession {
869    /// Creates a new builder for `TpmPolicySession`.
870    #[must_use]
871    pub fn builder() -> TpmPolicySessionBuilder {
872        TpmPolicySessionBuilder::new()
873    }
874
875    /// Returns the session handle.
876    #[must_use]
877    pub fn handle(&self) -> TpmHandle {
878        self.handle
879    }
880
881    /// Returns the session attributes.
882    #[must_use]
883    pub fn attributes(&self) -> TpmaSession {
884        self.attributes
885    }
886
887    /// Returns the hash algorithm used by the session.
888    #[must_use]
889    pub fn hash_alg(&self) -> TpmAlgId {
890        self.hash_alg
891    }
892
893    /// Returns the nonce generated by the TPM.
894    #[must_use]
895    pub fn nonce_tpm(&self) -> &Tpm2bNonce {
896        &self.nonce_tpm
897    }
898
899    /// Applies a list of policy commands to this session.
900    ///
901    /// This method iterates through the provided commands, updates the first handle
902    /// of each command (or second for `PolicySecret`) to point to this session,
903    /// and transmits them to the device.
904    ///
905    /// # Errors
906    ///
907    /// Returns [`MalformedData`](crate::TpmDeviceError::MalformedData) if a command
908    /// structure is not recognized as a supported policy command.
909    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from transmission.
910    pub fn run(
911        &self,
912        device: &mut TpmDevice,
913        commands: Vec<(TpmCommand, TpmAuthCommands)>,
914    ) -> Result<(), TpmDeviceError> {
915        for (mut command_body, auth_sessions) in commands {
916            match &mut command_body {
917                TpmCommand::PolicyPcr(cmd) => cmd.handles[0] = self.handle,
918                TpmCommand::PolicyOr(cmd) => cmd.handles[0] = self.handle,
919                TpmCommand::PolicyRestart(cmd) => {
920                    cmd.handles[0] = self.handle;
921                }
922                TpmCommand::PolicySecret(cmd) => {
923                    cmd.handles[1] = self.handle;
924                }
925                _ => {
926                    return Err(TpmDeviceError::MalformedData);
927                }
928            }
929            device.transmit(&command_body, auth_sessions.as_ref())?;
930        }
931        Ok(())
932    }
933
934    /// Flushes the session context from the TPM.
935    ///
936    /// # Errors
937    ///
938    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from [`flush_context`](TpmDevice::flush_context).
939    pub fn flush(&self, device: &mut TpmDevice) -> Result<(), TpmDeviceError> {
940        device.flush_context(self.handle)
941    }
942}