Documentation
// SPDX-License-Identifier: GPL-3.0-or-later
// Copyright (c) 2025 Opinsys Oy
// Copyright (c) 2024-2025 Jarkko Sakkinen

use std::{convert::TryFrom, str::FromStr};
use thiserror::Error;
use tpm2_protocol::data::TpmHt;

#[derive(Debug, Error)]
pub enum HandleError {
    #[error("invalid handle")]
    InvalidHandle,
}

/// The namespace a [`Handle`] lives in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandleClass {
    /// TPM handle; rendered with the `tpm:` prefix.
    Tpm,
    /// vTPM handle; rendered with the `vtpm:` prefix.
    Vtpm,
}

/// A handle: its class paired with a raw 32-bit value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Handle(pub (HandleClass, u32));

impl Handle {
    /// Returns the class component of the handle.
    #[must_use]
    pub fn class(&self) -> HandleClass {
        let (class, _) = self.0;
        class
    }

    /// Returns the raw 32-bit value component of the handle.
    #[must_use]
    pub fn value(&self) -> u32 {
        let (_, value) = self.0;
        value
    }
}

impl std::fmt::Display for Handle {
    /// Writes `<class>:<value>`, where the value is eight zero-padded
    /// lower-case hex digits (e.g. `tpm:81000001`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Handle((class, value)) = *self;
        if class == HandleClass::Tpm {
            write!(f, "tpm:{value:08x}")
        } else {
            write!(f, "vtpm:{value:08x}")
        }
    }
}

impl FromStr for Handle {
    type Err = HandleError;

    /// Parses a handle written as `<class>:<hex>`, e.g. `tpm:81000001` or
    /// `vtpm:0000001f` (the form produced by `Display`).
    ///
    /// The class must be exactly `tpm` or `vtpm`. The value must be a
    /// non-empty run of hex digits (either case) that fits in 32 bits.
    ///
    /// # Errors
    ///
    /// Returns `HandleError::InvalidHandle` when:
    /// - the `:` separator is missing,
    /// - the class is unknown, or
    /// - the value is empty, contains a non-hex character, or overflows `u32`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (scheme_str, value_str) = s.split_once(':').ok_or(HandleError::InvalidHandle)?;

        let class = match scheme_str {
            "tpm" => HandleClass::Tpm,
            "vtpm" => HandleClass::Vtpm,
            _ => return Err(HandleError::InvalidHandle),
        };

        // `u32::from_str_radix` tolerates a leading `+` sign, which is not
        // part of the handle syntax; accept hex digits only.
        if !value_str.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(HandleError::InvalidHandle);
        }

        // Still fails for an empty string or a value wider than 32 bits.
        let value = u32::from_str_radix(value_str, 16).map_err(|_| HandleError::InvalidHandle)?;

        Ok(Handle((class, value)))
    }
}

impl TryFrom<Handle> for TpmHt {
    type Error = HandleError;

    /// Converts the most significant byte of the handle value into a `TpmHt`.
    ///
    /// Only the value is inspected; the handle's class is ignored.
    ///
    /// # Errors
    ///
    /// Returns `HandleError::InvalidHandle` when `TpmHt` rejects that byte.
    fn try_from(handle: Handle) -> Result<Self, Self::Error> {
        // Big-endian byte order puts the top byte (bits 31..24) first.
        let [top_byte, ..] = handle.value().to_be_bytes();
        TpmHt::try_from(top_byte).map_err(|()| HandleError::InvalidHandle)
    }
}

/// Reasons a query string is rejected when compiling a `HandlePattern`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum HandlePatternError {
    /// Some character is neither a hex digit nor the `?` wildcard.
    #[error("handle pattern is not a valid hex string")]
    InvalidHexString,
    /// A pattern without `*` is shorter than eight characters.
    #[error("handle pattern has less than eight characters")]
    TooFewDigits,
    /// The pattern contains a second `*`.
    #[error("handle pattern has more than one '*'")]
    TooManyAsterisks,
    /// The pattern has more than eight characters besides the `*`.
    #[error("handle pattern has more than eight characters")]
    TooManyDigits,
}

/// Pattern matcher for a 32-bit unsigned value represented
/// as a string of exactly eight hex digits.
///
/// Wildcards:
///
/// 1. `?` matches any digit and can be used in place of any digit.
/// 2. `*` matches zero or more digits. Only one `*` is supported per pattern.
///
/// The compiled form is a `(mask, value)` pair: a handle matches when
/// `handle & mask == value`. Nibbles covered by a wildcard have a zero mask.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HandlePattern(u32, u32);

impl HandlePattern {
    /// Number of hex digits in a fully specified 32-bit value.
    const DIGITS: usize = 8;

    /// Compiles a query string into a `HandlePattern`.
    ///
    /// Without `*`, the query must be exactly eight characters, each either
    /// a hex digit or `?`.
    ///
    /// With a single `*`:
    /// - characters before it are anchored to the most significant nibbles;
    /// - characters after it are anchored to the least significant nibbles;
    /// - together they may not exceed eight characters.
    ///
    /// The query `"*"` on its own matches every value. Lengths are counted
    /// in characters, not bytes.
    ///
    /// # Errors
    ///
    /// Returns `HandlePatternError` if the query string is invalid:
    ///
    /// - `TooManyAsterisks` if it contains more than one `*`,
    /// - `TooFewDigits` if it has no `*` and fewer than eight characters,
    /// - `TooManyDigits` if it has more than eight characters besides `*`,
    /// - `InvalidHexString` if a character is neither a hex digit nor `?`.
    pub fn new(query: &str) -> Result<Self, HandlePatternError> {
        // Lengths are measured in characters (not `str::len` bytes) so that
        // multi-byte input is not over-counted against the eight-character
        // limit the error messages describe.
        let (prefix, suffix) = match query.split_once('*') {
            Some((prefix, suffix)) => {
                if suffix.contains('*') {
                    return Err(HandlePatternError::TooManyAsterisks);
                }
                if prefix.chars().count() + suffix.chars().count() > Self::DIGITS {
                    return Err(HandlePatternError::TooManyDigits);
                }
                (prefix, suffix)
            }
            None => {
                let count = query.chars().count();
                if count < Self::DIGITS {
                    return Err(HandlePatternError::TooFewDigits);
                }
                if count > Self::DIGITS {
                    return Err(HandlePatternError::TooManyDigits);
                }
                (query, "")
            }
        };

        // Pair each character with the bit offset of the nibble it
        // constrains. The prefix fills downwards from the most significant
        // nibble; the suffix fills upwards from the least significant one.
        // The checks above keep every offset within 0..=28. For the bare
        // query "*", both halves are empty and the pattern stays (0, 0).
        let high = prefix
            .chars()
            .enumerate()
            .map(|(i, c)| ((Self::DIGITS - 1 - i) * 4, c));
        let low = suffix.chars().rev().enumerate().map(|(i, c)| (i * 4, c));

        let mut mask: u32 = 0;
        let mut value: u32 = 0;
        for (shift, c) in high.chain(low) {
            match c.to_digit(16) {
                Some(digit) => {
                    mask |= 0xF << shift;
                    value |= digit << shift;
                }
                // `?` leaves the nibble unconstrained (its mask bits stay 0).
                None if c == '?' => {}
                None => return Err(HandlePatternError::InvalidHexString),
            }
        }

        Ok(Self(mask, value))
    }

    /// Checks if a given handle matches the compiled pattern.
    ///
    /// Wildcard nibbles are masked out before comparing. The all-wildcard
    /// pattern (zero mask, zero value) therefore matches every handle.
    #[must_use]
    pub fn matches(&self, handle: u32) -> bool {
        (handle & self.0) == self.1
    }
}