tpm2_device/
lib.rs

1// SPDX-License-Identifier: GPL-3-0-or-later
2// Copyright (c) 2025 Opinsys Oy
3// Copyright (c) 2024-2025 Jarkko Sakkinen
4
5#![deny(clippy::all)]
6#![deny(clippy::pedantic)]
7
8use nix::{
9    fcntl,
10    poll::{poll, PollFd, PollFlags},
11};
12use std::{
13    cell::RefCell,
14    collections::HashMap,
15    fs::{File, OpenOptions},
16    io::{Read, Write},
17    os::fd::{AsFd, AsRawFd},
18    path::{Path, PathBuf},
19    rc::Rc,
20    time::{Duration, Instant},
21};
22
23use thiserror::Error;
24use tpm2_protocol::{
25    basic::{TpmHandle, TpmUint32},
26    constant::{MAX_HANDLES, TPM_MAX_COMMAND_SIZE},
27    data::{
28        Tpm2bName, TpmAlgId, TpmCap, TpmCc, TpmEccCurve, TpmHt, TpmPt, TpmRc, TpmRcBase, TpmSt,
29        TpmsAlgProperty, TpmsAuthCommand, TpmsCapabilityData, TpmsContext, TpmsPcrSelection,
30        TpmtPublic, TpmuCapabilities,
31    },
32    frame::{
33        tpm_marshal_command, tpm_unmarshal_response, TpmAuthResponses, TpmContextLoadCommand,
34        TpmContextSaveCommand, TpmFlushContextCommand, TpmFrame, TpmGetCapabilityCommand,
35        TpmGetCapabilityResponse, TpmReadPublicCommand, TpmResponse,
36    },
37    TpmWriter,
38};
39use tracing::trace;
40
/// Errors that can occur when talking to a TPM device.
///
/// `TpmRc`, `std::io::Error` and `nix::Error` convert into this type via
/// `From`, so they can be propagated with `?`.
#[derive(Debug, Error)]
pub enum TpmDeviceError {
    /// The shared device cell is already mutably borrowed (raised by
    /// `with_device`).
    #[error("device is already borrowed")]
    AlreadyBorrowed,
    /// The TPM reply did not contain the requested capability class, or did
    /// not list the requested property.
    #[error("capability not found: {0}")]
    CapabilityMissing(TpmCap),
    /// The interrupt callback requested cancellation while a response was
    /// being awaited.
    #[error("operation interrupted by user")]
    Interrupted,
    /// The response header announced a size outside `10..=TPM_MAX_COMMAND_SIZE`,
    /// or more bytes arrived than the header announced.
    #[error("invalid response")]
    InvalidResponse,

    /// An I/O failure on the device file. `nix` errno values are converted
    /// into this variant as well.
    #[error("I/O: {0}")]
    Io(#[from] std::io::Error),

    /// Marshaling a TPM protocol encoded object failed.
    #[error("marshal: {0}")]
    Marshal(tpm2_protocol::TpmProtocolError),

    /// No device was supplied (raised by `with_device` for `None`).
    #[error("device not available")]
    NotAvailable,
    /// A generic failure; used when `MAX_HANDLES` does not fit in a `u32`
    /// page size.
    #[error("operation failed")]
    OperationFailed,
    /// The TPM reported no PCR banks.
    #[error("PCR banks not available")]
    PcrBanksNotAvailable,
    /// The TPM reported PCR banks whose selection sizes differ.
    #[error("PCR bank size mismatch")]
    PcrBankSizeMismatch,

    /// The TPM response did not match the expected command code.
    #[error("response mismatch: {0}")]
    ResponseMismatch(TpmCc),

    /// No complete response arrived within the configured timeout.
    #[error("TPM command timed out")]
    Timeout,
    /// The TPM answered with a non-success response code.
    #[error("TPM return code: {0}")]
    TpmRc(TpmRc),

    /// Unmarshaling a TPM protocol encoded object failed.
    #[error("unmarshal: {0}")]
    Unmarshal(tpm2_protocol::TpmProtocolError),

    /// The device reported an error or hang-up, or a read returned zero
    /// bytes.
    #[error("unexpected EOF")]
    UnexpectedEof,
}
85
86impl From<TpmRc> for TpmDeviceError {
87    fn from(rc: TpmRc) -> Self {
88        Self::TpmRc(rc)
89    }
90}
91
92impl From<nix::Error> for TpmDeviceError {
93    fn from(err: nix::Error) -> Self {
94        Self::Io(std::io::Error::from_raw_os_error(err as i32))
95    }
96}
97
98/// Executes a closure with a mutable reference to a `TpmDevice`.
99///
100/// This helper function centralizes the boilerplate for safely acquiring a
101/// mutable borrow of a `TpmDevice` from the shared `Rc<RefCell<...>>`.
102///
103/// # Errors
104///
105/// Returns [`NotAvailable`](crate::TpmDeviceError::NotAvailable) when no device
106/// is present and [`AlreadyBorrowed`](crate::TpmDeviceError::AlreadyBorrowed)
107/// when the device is already mutably borrowed, both converted into the caller's
108/// error type `E`. Propagates any error returned by the closure `f`.
109pub fn with_device<F, T, E>(device: Option<Rc<RefCell<TpmDevice>>>, f: F) -> Result<T, E>
110where
111    F: FnOnce(&mut TpmDevice) -> Result<T, E>,
112    E: From<TpmDeviceError>,
113{
114    let device_rc = device.ok_or(TpmDeviceError::NotAvailable)?;
115    let mut device_guard = device_rc
116        .try_borrow_mut()
117        .map_err(|_| TpmDeviceError::AlreadyBorrowed)?;
118    f(&mut device_guard)
119}
120
/// Collects the settings needed to open a `TpmDevice`.
///
/// Start from `TpmDevice::builder()` (or `TpmDeviceBuilder::default()`),
/// adjust it with the `with_*` methods, then call `build()`.
pub struct TpmDeviceBuilder {
    /// Callback polled while a response is awaited; returning `true` aborts
    /// the command.
    interrupted: Box<dyn Fn() -> bool>,
    /// Character device node that `build()` opens for reading and writing.
    path: PathBuf,
    /// Maximum time a command may wait for its response.
    timeout: Duration,
}
127
128impl Default for TpmDeviceBuilder {
129    fn default() -> Self {
130        Self {
131            path: PathBuf::from("/dev/tpmrm0"),
132            timeout: Duration::from_secs(120),
133            interrupted: Box::new(|| false),
134        }
135    }
136}
137
138impl TpmDeviceBuilder {
139    /// Sets the device file path.
140    #[must_use]
141    pub fn with_path<P: AsRef<Path>>(mut self, path: P) -> Self {
142        self.path = path.as_ref().to_path_buf();
143        self
144    }
145
146    /// Sets the operation timeout.
147    #[must_use]
148    pub fn with_timeout(mut self, timeout: Duration) -> Self {
149        self.timeout = timeout;
150        self
151    }
152
153    /// Sets the interruption check callback.
154    #[must_use]
155    pub fn with_interrupted<F>(mut self, handler: F) -> Self
156    where
157        F: Fn() -> bool + 'static,
158    {
159        self.interrupted = Box::new(handler);
160        self
161    }
162
163    /// Opens the TPM device file and constructs the `TpmDevice`.
164    ///
165    /// # Errors
166    ///
167    /// Returns [`Io`](crate::TpmDeviceError::Io) when the device file cannot be
168    /// opened or when configuring the file descriptor flags fails.
169    pub fn build(self) -> Result<TpmDevice, TpmDeviceError> {
170        let file = OpenOptions::new()
171            .read(true)
172            .write(true)
173            .open(&self.path)
174            .map_err(TpmDeviceError::Io)?;
175
176        let fd = file.as_raw_fd();
177        let flags = fcntl::fcntl(fd, fcntl::FcntlArg::F_GETFL)?;
178        let mut oflags = fcntl::OFlag::from_bits_truncate(flags);
179        oflags.insert(fcntl::OFlag::O_NONBLOCK);
180        fcntl::fcntl(fd, fcntl::FcntlArg::F_SETFL(oflags))?;
181
182        Ok(TpmDevice {
183            file,
184            name_cache: HashMap::new(),
185            interrupted: self.interrupted,
186            timeout: self.timeout,
187            command: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
188            response: Vec::with_capacity(TPM_MAX_COMMAND_SIZE),
189        })
190    }
191}
192
193pub struct TpmDevice {
194    file: File,
195    name_cache: HashMap<u32, (TpmtPublic, Tpm2bName)>,
196    interrupted: Box<dyn Fn() -> bool>,
197    timeout: Duration,
198    command: Vec<u8>,
199    response: Vec<u8>,
200}
201
202impl std::fmt::Debug for TpmDevice {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        f.debug_struct("Device")
205            .field("file", &self.file)
206            .field("name_cache", &self.name_cache)
207            .field("timeout", &self.timeout)
208            .finish_non_exhaustive()
209    }
210}
211
212impl TpmDevice {
213    const NO_SESSIONS: &'static [TpmsAuthCommand] = &[];
214
215    /// Creates a new builder for `TpmDevice`.
216    #[must_use]
217    pub fn builder() -> TpmDeviceBuilder {
218        TpmDeviceBuilder::default()
219    }
220
221    fn receive(&mut self, buf: &mut [u8]) -> Result<usize, TpmDeviceError> {
222        let fd = self.file.as_fd();
223        let mut fds = [PollFd::new(fd, PollFlags::POLLIN)];
224
225        let num_events = match poll(&mut fds, 100u16) {
226            Ok(num) => num,
227            Err(nix::Error::EINTR) => return Ok(0),
228            Err(e) => return Err(e.into()),
229        };
230
231        if num_events == 0 {
232            return Ok(0);
233        }
234
235        let revents = fds[0].revents().unwrap_or(PollFlags::empty());
236
237        if revents.intersects(PollFlags::POLLERR | PollFlags::POLLNVAL) {
238            return Err(TpmDeviceError::UnexpectedEof);
239        }
240
241        if revents.contains(PollFlags::POLLIN) {
242            match self.file.read(buf) {
243                Ok(0) => Err(TpmDeviceError::UnexpectedEof),
244                Ok(n) => Ok(n),
245                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => Ok(0),
246                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok(0),
247                Err(e) => Err(e.into()),
248            }
249        } else if revents.contains(PollFlags::POLLHUP) {
250            Err(TpmDeviceError::UnexpectedEof)
251        } else {
252            Ok(0)
253        }
254    }
255
256    /// Performs the whole TPM command transmission process.
257    ///
258    /// # Errors
259    ///
260    /// Returns [`Interrupted`](crate::TpmDeviceError::Interrupted) when the
261    /// interrupt callback requests cancellation.
262    /// Returns [`Timeout`](crate::TpmDeviceError::Timeout) when the TPM does
263    /// not respond within the configured timeout.
264    /// Returns [`Io`](crate::TpmDeviceError::Io) when a write, flush, or read
265    /// operation on the device file fails, or when polling the device file
266    /// descriptor fails.
267    /// Returns [`InvalidResponse`](crate::TpmDeviceError::InvalidResponse) or
268    /// [`UnexpectedEof`](crate::TpmDeviceError::UnexpectedEof) when the TPM
269    /// reply is malformed, truncated, or longer than the announced size.
270    /// Returns [`Marshal`](crate::TpmDeviceError::Marshal) or
271    /// [`Unmarshal`](crate::TpmDeviceError::Unmarshal) when encoding the
272    /// command or decoding the response fails.
273    /// Returns [`TpmRc`](crate::TpmDeviceError::TpmRc) when the TPM returns an
274    /// error code.
275    pub fn transmit<C: TpmFrame>(
276        &mut self,
277        command: &C,
278        sessions: &[TpmsAuthCommand],
279    ) -> Result<(TpmResponse, TpmAuthResponses), TpmDeviceError> {
280        self.prepare_command(command, sessions)?;
281        let cc = command.cc();
282
283        self.file.write_all(&self.command)?;
284        self.file.flush()?;
285
286        let start_time = Instant::now();
287        self.response.clear();
288        let mut total_size: Option<usize> = None;
289        let mut temp_buf = [0u8; 1024];
290
291        loop {
292            if (self.interrupted)() {
293                return Err(TpmDeviceError::Interrupted);
294            }
295            if start_time.elapsed() > self.timeout {
296                return Err(TpmDeviceError::Timeout);
297            }
298
299            let n = self.receive(&mut temp_buf)?;
300            if n > 0 {
301                self.response.extend_from_slice(&temp_buf[..n]);
302            }
303
304            if total_size.is_none() && self.response.len() >= 10 {
305                let Ok(size_bytes): Result<[u8; 4], _> = self.response[2..6].try_into() else {
306                    return Err(TpmDeviceError::InvalidResponse);
307                };
308                let size = u32::from_be_bytes(size_bytes) as usize;
309                if !(10..={ TPM_MAX_COMMAND_SIZE }).contains(&size) {
310                    return Err(TpmDeviceError::InvalidResponse);
311                }
312                total_size = Some(size);
313            }
314
315            if let Some(size) = total_size {
316                if self.response.len() == size {
317                    break;
318                }
319                if self.response.len() > size {
320                    return Err(TpmDeviceError::InvalidResponse);
321                }
322            }
323        }
324
325        let result = tpm_unmarshal_response(cc, &self.response).map_err(TpmDeviceError::Unmarshal);
326        trace!("{} R: {}", cc, hex::encode(&self.response));
327        Ok(result??)
328    }
329
330    fn prepare_command<C: TpmFrame>(
331        &mut self,
332        command: &C,
333        sessions: &[TpmsAuthCommand],
334    ) -> Result<(), TpmDeviceError> {
335        let cc = command.cc();
336        let tag = if sessions.is_empty() {
337            TpmSt::NoSessions
338        } else {
339            TpmSt::Sessions
340        };
341
342        self.command.resize(TPM_MAX_COMMAND_SIZE, 0);
343
344        let len = {
345            let mut writer = TpmWriter::new(&mut self.command);
346            tpm_marshal_command(command, tag, sessions, &mut writer)
347                .map_err(TpmDeviceError::Marshal)?;
348            writer.len()
349        };
350        self.command.truncate(len);
351
352        trace!("{} C: {}", cc, hex::encode(&self.command));
353        Ok(())
354    }
355
356    /// Fetches a complete list of capabilities from the TPM, handling
357    /// pagination.
358    ///
359    /// # Errors
360    ///
361    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) returned by
362    /// [`get_capability_page`](TpmDevice::get_capability_page) or by the
363    /// `extract` closure.
364    fn get_capability<T, F, N>(
365        &mut self,
366        cap: TpmCap,
367        property_start: u32,
368        count: u32,
369        mut extract: F,
370        next_prop: N,
371    ) -> Result<Vec<T>, TpmDeviceError>
372    where
373        T: Copy,
374        F: for<'a> FnMut(&'a TpmuCapabilities) -> Result<&'a [T], TpmDeviceError>,
375        N: Fn(&T) -> u32,
376    {
377        let mut results = Vec::new();
378        let mut prop = property_start;
379        loop {
380            let (more_data, cap_data) =
381                self.get_capability_page(cap, TpmUint32(prop), TpmUint32(count))?;
382            let items: &[T] = extract(&cap_data.data)?;
383            results.extend_from_slice(items);
384
385            if more_data {
386                if let Some(last) = items.last() {
387                    prop = next_prop(last);
388                } else {
389                    break;
390                }
391            } else {
392                break;
393            }
394        }
395        Ok(results)
396    }
397
398    /// Retrieves all algorithm properties supported by the TPM.
399    ///
400    /// # Errors
401    ///
402    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
403    /// the handle count cannot be represented as `u32`. Propagates any
404    /// [`TpmDeviceError`](crate::TpmDeviceError) from
405    /// [`get_capability`](TpmDevice::get_capability), including
406    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
407    /// TPM does not report algorithm properties.
408    pub fn fetch_algorithm_properties(&mut self) -> Result<Vec<TpmsAlgProperty>, TpmDeviceError> {
409        self.get_capability(
410            TpmCap::Algs,
411            0,
412            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
413            |caps| match caps {
414                TpmuCapabilities::Algs(algs) => Ok(algs),
415                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Algs)),
416            },
417            |last| last.alg as u32 + 1,
418        )
419    }
420
421    /// Retrieves all handles of a specific type from the TPM.
422    ///
423    /// # Errors
424    ///
425    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
426    /// the handle count cannot be represented as `u32`. Propagates any
427    /// [`TpmDeviceError`](crate::TpmDeviceError) from
428    /// [`get_capability`](TpmDevice::get_capability), including
429    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
430    /// TPM does not report handles of the requested class.
431    pub fn fetch_handles(&mut self, class: TpmHt) -> Result<Vec<TpmHandle>, TpmDeviceError> {
432        self.get_capability(
433            TpmCap::Handles,
434            (class as u32) << 24,
435            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
436            |caps| match caps {
437                TpmuCapabilities::Handles(handles) => Ok(handles),
438                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Handles)),
439            },
440            |last| last.value() + 1,
441        )
442        .map(|handles| handles.into_iter().collect())
443    }
444
445    /// Retrieves all available ECC curves supported by the TPM.
446    ///
447    /// # Errors
448    ///
449    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
450    /// the handle count cannot be represented as `u32`. Propagates any
451    /// [`TpmDeviceError`](crate::TpmDeviceError) from
452    /// [`get_capability`](TpmDevice::get_capability), including
453    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
454    /// TPM does not report ECC curves.
455    pub fn fetch_ecc_curves(&mut self) -> Result<Vec<TpmEccCurve>, TpmDeviceError> {
456        self.get_capability(
457            TpmCap::EccCurves,
458            0,
459            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
460            |caps| match caps {
461                TpmuCapabilities::EccCurves(curves) => Ok(curves),
462                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::EccCurves)),
463            },
464            |last| *last as u32 + 1,
465        )
466    }
467
468    /// Retrieves the list of active PCR banks and the bank size.
469    ///
470    /// # Errors
471    ///
472    /// Returns [`OperationFailed`](crate::TpmDeviceError::OperationFailed) when
473    /// the handle count cannot be represented as `u32`. Propagates any
474    /// [`TpmDeviceError`](crate::TpmDeviceError) from
475    /// [`get_capability`](TpmDevice::get_capability), including
476    /// [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing) when the
477    /// TPM does not report PCRs.
478    /// Returns [`PcrBanksNotAvailable`](crate::TpmDeviceError::PcrBanksNotAvailable)
479    /// if the list of banks is empty.
480    /// Returns [`PcrBankSizeMismatch`](crate::TpmDeviceError::PcrBankSizeMismatch)
481    /// if bank sizes are inconsistent.
482    pub fn fetch_pcr_bank_list(&mut self) -> Result<(usize, Vec<TpmAlgId>), TpmDeviceError> {
483        let pcrs: Vec<TpmsPcrSelection> = self.get_capability(
484            TpmCap::Pcrs,
485            0,
486            u32::try_from(MAX_HANDLES).map_err(|_| TpmDeviceError::OperationFailed)?,
487            |caps| match caps {
488                TpmuCapabilities::Pcrs(pcrs) => Ok(pcrs),
489                _ => Err(TpmDeviceError::CapabilityMissing(TpmCap::Pcrs)),
490            },
491            |last| last.hash as u32 + 1,
492        )?;
493
494        if pcrs.is_empty() {
495            return Err(TpmDeviceError::PcrBanksNotAvailable);
496        }
497
498        let mut count = 0;
499        let mut algs = Vec::with_capacity(pcrs.len());
500
501        for bank in pcrs {
502            let next_count = bank.pcr_select.len();
503            if count == 0 {
504                count = next_count;
505            }
506            if next_count != count {
507                return Err(TpmDeviceError::PcrBankSizeMismatch);
508            }
509            algs.push(bank.hash);
510        }
511
512        algs.sort();
513        Ok((count, algs))
514    }
515
516    /// Fetches and returns one page of capabilities of a certain type from the
517    /// TPM.
518    ///
519    /// # Errors
520    ///
521    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
522    /// [`transmit`](TpmDevice::transmit). Returns
523    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
524    /// TPM response does not contain `TPM2_GetCapability` data.
525    fn get_capability_page(
526        &mut self,
527        cap: TpmCap,
528        property: TpmUint32,
529        property_count: TpmUint32,
530    ) -> Result<(bool, TpmsCapabilityData), TpmDeviceError> {
531        let cmd = TpmGetCapabilityCommand {
532            cap,
533            property,
534            property_count,
535            handles: [],
536        };
537
538        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
539        let TpmGetCapabilityResponse {
540            more_data,
541            capability_data,
542            handles: [],
543        } = resp
544            .GetCapability()
545            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::GetCapability))?;
546
547        Ok((more_data.into(), capability_data))
548    }
549
550    /// Reads a specific TPM property.
551    ///
552    /// # Errors
553    ///
554    /// Returns [`CapabilityMissing`](crate::TpmDeviceError::CapabilityMissing)
555    /// when the TPM does not report the requested property. Propagates any
556    /// [`TpmDeviceError`](crate::TpmDeviceError) from
557    /// [`get_capability_page`](TpmDevice::get_capability_page).
558    pub fn get_tpm_property(&mut self, property: TpmPt) -> Result<TpmUint32, TpmDeviceError> {
559        let (_, cap_data) = self.get_capability_page(
560            TpmCap::TpmProperties,
561            TpmUint32(property as u32),
562            TpmUint32(1),
563        )?;
564
565        let TpmuCapabilities::TpmProperties(props) = &cap_data.data else {
566            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
567        };
568
569        let Some(prop) = props.iter().find(|prop| prop.property == property) else {
570            return Err(TpmDeviceError::CapabilityMissing(TpmCap::TpmProperties));
571        };
572
573        Ok(prop.value)
574    }
575
576    /// Reads the public area of a TPM object.
577    ///
578    /// # Errors
579    ///
580    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
581    /// [`transmit`](TpmDevice::transmit). Returns
582    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
583    /// TPM response does not contain `TPM2_ReadPublic` data.
584    pub fn read_public(
585        &mut self,
586        handle: TpmHandle,
587    ) -> Result<(TpmtPublic, Tpm2bName), TpmDeviceError> {
588        if let Some(cached) = self.name_cache.get(&handle.0) {
589            return Ok(cached.clone());
590        }
591
592        let cmd = TpmReadPublicCommand { handles: [handle] };
593        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
594
595        let read_public_resp = resp
596            .ReadPublic()
597            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ReadPublic))?;
598
599        let public = read_public_resp.out_public.inner;
600        let name = read_public_resp.name;
601
602        self.name_cache.insert(handle.0, (public.clone(), name));
603        Ok((public, name))
604    }
605
606    /// Finds a persistent handle by its `Tpm2bName`.
607    ///
608    /// # Errors
609    ///
610    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
611    /// [`fetch_handles`](TpmDevice::fetch_handles) and
612    /// [`read_public`](TpmDevice::read_public), except for TPM reference and
613    /// handle errors with base
614    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
615    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
616    /// invalid handles and skipped.
617    pub fn find_persistent(
618        &mut self,
619        target_name: &Tpm2bName,
620    ) -> Result<Option<TpmHandle>, TpmDeviceError> {
621        for handle in self.fetch_handles(TpmHt::Persistent)? {
622            match self.read_public(handle) {
623                Ok((_, name)) => {
624                    if name == *target_name {
625                        return Ok(Some(handle));
626                    }
627                }
628                Err(TpmDeviceError::TpmRc(rc)) => {
629                    let base = rc.base();
630                    if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
631                        continue;
632                    }
633                    return Err(TpmDeviceError::TpmRc(rc));
634                }
635                Err(e) => return Err(e),
636            }
637        }
638        Ok(None)
639    }
640
641    /// Saves the context of a transient object or session.
642    ///
643    /// # Errors
644    ///
645    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
646    /// [`transmit`](TpmDevice::transmit). Returns
647    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
648    /// TPM response does not contain `TPM2_ContextSave` data.
649    pub fn save_context(&mut self, save_handle: TpmHandle) -> Result<TpmsContext, TpmDeviceError> {
650        let cmd = TpmContextSaveCommand {
651            handles: [save_handle],
652        };
653        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
654        let save_resp = resp
655            .ContextSave()
656            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextSave))?;
657        Ok(save_resp.context)
658    }
659
660    /// Loads a TPM context and returns the handle.
661    ///
662    /// # Errors
663    ///
664    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
665    /// [`transmit`](TpmDevice::transmit). Returns
666    /// [`ResponseMismatch`](crate::TpmDeviceError::ResponseMismatch) when the
667    /// TPM response does not contain `TPM2_ContextLoad` data.
668    pub fn load_context(&mut self, context: TpmsContext) -> Result<TpmHandle, TpmDeviceError> {
669        let cmd = TpmContextLoadCommand {
670            context,
671            handles: [],
672        };
673        let (resp, _) = self.transmit(&cmd, Self::NO_SESSIONS)?;
674        let resp_inner = resp
675            .ContextLoad()
676            .map_err(|_| TpmDeviceError::ResponseMismatch(TpmCc::ContextLoad))?;
677        Ok(resp_inner.handles[0])
678    }
679
680    /// Flushes a transient object or session from the TPM and removes it from
681    /// the cache.
682    ///
683    /// # Errors
684    ///
685    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
686    /// [`transmit`](TpmDevice::transmit).
687    pub fn flush_context(&mut self, handle: TpmHandle) -> Result<(), TpmDeviceError> {
688        self.name_cache.remove(&handle.0);
689        let cmd = TpmFlushContextCommand {
690            flush_handle: handle,
691            handles: [],
692        };
693        self.transmit(&cmd, Self::NO_SESSIONS)?;
694        Ok(())
695    }
696
697    /// Loads a session context and then flushes the resulting handle.
698    ///
699    /// # Errors
700    ///
701    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
702    /// [`load_context`](TpmDevice::load_context) or
703    /// [`flush_context`](TpmDevice::flush_context) except for TPM reference
704    /// errors with base
705    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0) or
706    /// [`Handle`](tpm2_protocol::data::TpmRcBase::Handle), which are treated as
707    /// a successful no-op.
708    pub fn flush_session(&mut self, context: TpmsContext) -> Result<(), TpmDeviceError> {
709        match self.load_context(context) {
710            Ok(handle) => self.flush_context(handle),
711            Err(TpmDeviceError::TpmRc(rc)) => {
712                let base = rc.base();
713                if base == TpmRcBase::ReferenceH0 || base == TpmRcBase::Handle {
714                    Ok(())
715                } else {
716                    Err(TpmDeviceError::TpmRc(rc))
717                }
718            }
719            Err(e) => Err(e),
720        }
721    }
722
723    /// Refreshes a key context. Returns `true` if the context is still valid,
724    /// and `false` if it is stale.
725    ///
726    /// # Errors
727    ///
728    /// Propagates any [`TpmDeviceError`](crate::TpmDeviceError) from
729    /// [`load_context`](TpmDevice::load_context) or
730    /// [`flush_context`](TpmDevice::flush_context) except for TPM reference
731    /// errors with base
732    /// [`ReferenceH0`](tpm2_protocol::data::TpmRcBase::ReferenceH0), which are
733    /// treated as a stale context and reported as `Ok(false)`.
734    pub fn refresh_key(&mut self, context: TpmsContext) -> Result<bool, TpmDeviceError> {
735        match self.load_context(context) {
736            Ok(handle) => match self.flush_context(handle) {
737                Ok(()) => Ok(true),
738                Err(e) => Err(e),
739            },
740            Err(TpmDeviceError::TpmRc(rc)) if rc.base() == TpmRcBase::ReferenceH0 => Ok(false),
741            Err(e) => Err(e),
742        }
743    }
744}