Skip to main content

csv_rs/api/dcu/
mod.rs

1// Copyright (C) Hygon Info Technologies Ltd.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5
6use crate::error::*;
7mod ioctl;
8pub use ioctl::*;
9mod types;
10use crate::certs::{builtin::HRK, ca, csv, Verifiable};
11use codicon::Decoder;
12use log::*;
13use std::fs::{self, File, OpenOptions};
14use std::io::{self};
15use std::path::Path;
16pub use types::*;
17
/// Reads and parses the `gpu_id` attribute of KFD topology node `sysfs_node_id`.
///
/// The file contents are trimmed of surrounding whitespace before being parsed
/// as a decimal `u32`.
///
/// # Errors
/// - Any I/O error from reading the sysfs file is returned unchanged.
/// - An `InvalidData` error is returned when the trimmed contents are not a valid `u32`.
fn topology_sysfs_get_dcu_id(sysfs_node_id: u32) -> io::Result<u32> {
    let gpu_id_path = format!(
        "/sys/devices/virtual/kfd/kfd/topology/nodes/{}/gpu_id",
        sysfs_node_id
    );
    let contents = fs::read_to_string(&gpu_id_path)?;

    match contents.trim().parse::<u32>() {
        Ok(dcu_id) => Ok(dcu_id),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Failed to parse DCU ID",
        )),
    }
}
29
/// Counts the entries of `dirpath` whose names start with `prefix`.
///
/// An empty `prefix` matches every entry. Despite the function's name, entries
/// are not checked to be directories. Failures are not reported:
/// - a directory that cannot be opened counts as 0;
/// - individual entries that fail to read are skipped.
fn num_subdirs(dirpath: &str, prefix: &str) -> usize {
    let entries = match fs::read_dir(dirpath) {
        Ok(entries) => entries,
        Err(_) => return 0,
    };

    let mut matched = 0;
    for entry in entries.flatten() {
        let raw_name = entry.file_name();
        let name = raw_name.to_string_lossy();

        // `read_dir` does not normally yield `.`/`..`; the guard mirrors the
        // C implementation this was ported from.
        let is_dot_entry = name == "." || name == "..";
        let has_prefix = prefix.is_empty() || name.starts_with(prefix);

        if has_prefix && !is_dot_entry {
            matched += 1;
        }
    }
    matched
}
46
47/// A handle to the dcu device.
48pub struct DcuDevice(File);
49
50impl DcuDevice {
51    /// Opens a handle to the DCU device via `/dev/mkfd`.
52    pub fn new() -> io::Result<DcuDevice> {
53        OpenOptions::new()
54            .read(true)
55            .write(true)
56            .open("/dev/mkfd")
57            .map(DcuDevice)
58    }
59
60    /// Get attestation reports from all available DCU nodes
61    ///
62    /// # Arguments
63    /// * `userdata` - 64-byte user data value used for attestation request
64    ///
65    /// # Returns
66    /// - `Ok(Vec<AttestationReport>)` containing valid attestation reports
67    /// - `Err(Error)` if:
68    ///   - No valid reports are obtained
69    ///   - IOCTL operations fail
70    ///   - DCU node communication fails
71    pub fn get_report(&mut self, userdata: [u8; 64]) -> Result<Vec<AttestationReport>, Error> {
72        // Discover available DCU nodes
73        let num_node = num_subdirs("/sys/devices/virtual/kfd/kfd/topology/nodes", "");
74        let mut reports: Vec<AttestationReport> = Vec::with_capacity(num_node);
75
76        // Process each DCU node
77        for node in 0..num_node {
78            trace!("Processing node {} of {}", node, num_node);
79
80            if let Ok(dcu_id) = topology_sysfs_get_dcu_id(node as u32) {
81                trace!("Found DCU ID: {}", dcu_id);
82
83                // Skip invalid DCU IDs
84                if dcu_id == 0 {
85                    continue;
86                }
87
88                // Initialize attestation request
89                let mut args = MkfdIoctlSecurityAttestationArgs::new();
90                args.set_attestation_args(dcu_id, userdata)?;
91
92                // Execute IOCTL request
93                if DCU_GET_REPORT.ioctl(&mut self.0, &mut args)? == 0 {
94                    if let Some(report) = args.extract_report()? {
95                        debug!(
96                            "Get dcu report succeeded - Node: {}, DCU ID: {}",
97                            node, dcu_id
98                        );
99                        // Debug output and storage
100                        report.print_report();
101                        reports.push(report);
102                    }
103                }
104            }
105        }
106
107        // Validate we got at least one report
108        if reports.is_empty() {
109            Err(io::Error::new(
110                io::ErrorKind::NotFound,
111                "No valid attestation reports obtained from any DCU node",
112            )
113            .into())
114        } else {
115            Ok(reports)
116        }
117    }
118}
119
/// Verifies a batch of attestation reports, one after another.
///
/// For each report, in order:
/// 1. The certificate data for the report's `chip_id` is fetched with
///    `csv::cert::get_certificate_data`. Each fetch is awaited before the next
///    report is processed, so there is no concurrency between reports.
/// 2. The report is checked with [`verify_report`].
///
/// # Arguments
/// * `reports` - Slice of [`AttestationReport`] structures to verify
/// * `userdata` - 64-byte expected user data that every report must embed
///
/// # Errors
/// Returns the first fetch or verification error encountered. Later reports
/// are not examined.
#[cfg(feature = "network")]
pub async fn verify_reports(
    reports: &[AttestationReport],
    userdata: &[u8; 64],
) -> Result<(), Error> {
    for current in reports.iter() {
        let chip_id = &current.body.chip_id;
        let chain_bytes = csv::cert::get_certificate_data(chip_id).await?;
        verify_report(current, userdata, &chain_bytes)?;
    }
    Ok(())
}
143
144/// Performs complete verification of a single attestation report.
145///
146/// # Verification Pipeline
147/// 1. ​**Nonce Verification**:
148///    - Compares provided mnonce with report's embedded nonce
149/// 2. ​**Certificate Chain Decoding**:
150///    - HRK (Hygon Root Key) ← Predefined
151///    - HSK (Hygon Signing Key) ← From cert_data
152///    - CEK (Chip Endorsement Key) ← From cert_data
153/// 3. ​**Certificate Chain Validation**:
154///    - HRK → HSK → CEK → Report signature
155///
156/// # Arguments
157/// * `report` - Individual attestation report to verify
158/// * `mnonce` - Expected 16-byte nonce value
159/// * `cert_data` - DER-encoded certificate chain (HSK + CEK)
160///
161/// # Errors
162/// Returns specific validation errors for:
163/// - Certificate decoding failures
164/// - Chain validation failures
165/// - Nonce mismatches
166pub fn verify_report(
167    report: &AttestationReport,
168    userdata: &[u8; 64],
169    cert_data: &[u8],
170) -> Result<(), Error> {
171    let mut cert_slice = cert_data;
172
173    // Decode certificate chain
174    let hsk = ca::Certificate::decode(&mut cert_slice, ())?;
175    let cek = csv::Certificate::decode(&mut cert_slice, ())?;
176    let hrk = ca::Certificate::decode(&mut &HRK[..], ())?;
177
178    report.print_report();
179
180    // Critical security check: nonce matching
181    if userdata != &report.body.user_data {
182        return Err(
183            io::Error::new(io::ErrorKind::InvalidData, "Attestation nonce mismatch").into(),
184        );
185    }
186
187    // Validate certificate hierarchy
188    (&hrk, &hrk).verify()?; // HRK self-verification
189    (&hrk, &hsk).verify()?; // HRK → HSK
190    (&hsk, &cek).verify()?; // HSK → CEK
191    (&cek, report).verify()?; // CEK → Report
192
193    debug!(
194        "Successfully verified report for Chip ID: {}",
195        String::from_utf8_lossy(&report.body.chip_id)
196    );
197
198    Ok(())
199}
200
201/// Saves certificates to local files
202pub fn save_certificates(
203    hsk: &ca::Certificate,
204    cek: &csv::Certificate,
205    hrk: &ca::Certificate,
206    chip_id: [u8; 16],
207) -> Result<(), Error> {
208    // Define certificates directory path
209    let certs_dir = Path::new("/opt/dcu/certs");
210
211    // Create directory recursively if it doesn't exist (similar to mkdir -p)
212    if !certs_dir.exists() {
213        fs::create_dir_all(certs_dir).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
214    }
215
216    // Write HSK certificate
217    hsk.write_to_file(&certs_dir.join("hsk.cert"))?;
218
219    // Convert chip_id to string
220    let chip_id_str = String::from_utf8(chip_id.to_vec())
221        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
222
223    // Write CEK certificate (with chip_id in filename)
224    cek.write_to_file(&certs_dir.join(format!("{}_cek.cert", chip_id_str)))?;
225
226    // Write HRK certificate
227    hrk.write_to_file(&certs_dir.join("hrk.cert"))?;
228
229    Ok(())
230}