from __future__ import annotations
import base64
import hashlib
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
# Public API of this module: the abstract interface and its error type,
# one concrete signer per supported backend, and the config object plus
# the factory that selects a backend from it.
__all__ = [
    # Core interface
    "VaultSigner",
    "VaultSignerError",
    # Provider implementations
    "AwsKmsSigner",
    "GcpKmsSigner",
    "HashiCorpVaultSigner",
    "AzureKeyVaultSigner",
    "LocalFileSigner",
    # Configuration and factory
    "VaultSignerConfig",
    "build_signer",
]
class VaultSignerError(Exception):
    """Error raised by a signer backend, tagged with the backend that failed.

    The string form is ``"[<provider>] <message>"``.

    Attributes:
        provider: Short backend identifier (e.g. ``"aws_kms"``, ``"factory"``).
        message: The undecorated error message.
    """

    def __init__(self, provider: str, message: str) -> None:
        super().__init__(f"[{provider}] {message}")
        self.provider = provider
        self.message = message

    def __reduce__(self) -> tuple:
        # The default Exception.__reduce__ replays ``self.args`` (the single
        # pre-formatted string) into __init__, which requires two arguments,
        # so unpickling would raise TypeError. Replay the original arguments
        # instead so the error survives multiprocessing / concurrent.futures.
        return (type(self), (self.provider, self.message), self.__dict__)
class VaultSigner(ABC):
    """Abstract interface for a backend that produces signatures.

    Subclasses implement :meth:`sign` and :meth:`verifying_key_bytes`; the
    hex and fingerprint helpers are derived from the latter.
    """

    @abstractmethod
    def sign(self, payload: bytes) -> bytes:
        """Return the signature over ``payload``."""
        raise NotImplementedError

    @abstractmethod
    def verifying_key_bytes(self) -> bytes:
        """Return the public (verifying) key as raw bytes."""
        raise NotImplementedError

    def verifying_key_hex(self) -> str:
        """Return :meth:`verifying_key_bytes` as a lowercase hex string."""
        return self.verifying_key_bytes().hex()

    def key_fingerprint(self) -> str:
        """Return a short key identifier.

        This is the first 16 hex characters (64 bits) of the SHA-256 digest
        of the raw verifying key.
        """
        return hashlib.sha256(self.verifying_key_bytes()).hexdigest()[:16]
@dataclass
class VaultSignerConfig:
    """Flat configuration consumed by ``build_signer``.

    ``provider`` selects the backend, and only the fields relevant to that
    backend are read. Credential fields (``vault_token``, ``client_secret``)
    are excluded from ``repr()`` so that printing or logging a config does not
    leak secrets. They still take part in equality comparison.
    """

    # Backend selector: "aws_kms", "gcp_kms", "hashicorp_vault",
    # "azure_key_vault" or "local_file".
    provider: str
    # AWS KMS
    key_id: str = ""
    region: str = ""
    # GCP Cloud KMS
    project_id: str = ""
    location_id: str = ""
    key_ring_id: str = ""
    crypto_key_id: str = ""
    key_version_id: str = "1"
    # HashiCorp Vault Transit (key_name is also used by Azure Key Vault)
    vault_addr: str = ""
    vault_token: str = field(default="", repr=False)
    key_name: str = ""
    vault_namespace: str = ""
    vault_ca_cert: Optional[str] = None
    # Azure Key Vault (service-principal credentials)
    tenant_id: str = ""
    client_id: str = ""
    client_secret: str = field(default="", repr=False)
    vault_name: str = ""
    # Local key file
    key_path: str = ""
    # Additional settings; build_signer does not currently read this.
    extra: Dict[str, Any] = field(default_factory=dict)
def build_signer(config: VaultSignerConfig) -> VaultSigner:
    """Construct the concrete signer named by ``config.provider``.

    Each provider reads only its own subset of ``config`` fields; all other
    fields are ignored.

    Raises:
        VaultSignerError: ``config.provider`` matches none of the supported
            identifiers (reported with provider tag ``"factory"``).
    """
    provider = config.provider
    # (name, constructor) pairs, compared in order with ``==``. The lambdas
    # defer construction so only the matching backend is instantiated.
    registry = (
        (
            "aws_kms",
            lambda: AwsKmsSigner(key_id=config.key_id, region=config.region),
        ),
        (
            "gcp_kms",
            lambda: GcpKmsSigner(
                project_id=config.project_id,
                location_id=config.location_id,
                key_ring_id=config.key_ring_id,
                crypto_key_id=config.crypto_key_id,
                key_version_id=config.key_version_id,
            ),
        ),
        (
            "hashicorp_vault",
            lambda: HashiCorpVaultSigner(
                vault_addr=config.vault_addr,
                key_name=config.key_name,
                token=config.vault_token,
                # An empty namespace string means "no namespace".
                namespace=config.vault_namespace or None,
                ca_cert_path=config.vault_ca_cert,
            ),
        ),
        (
            "azure_key_vault",
            lambda: AzureKeyVaultSigner(
                vault_name=config.vault_name,
                key_name=config.key_name,
                tenant_id=config.tenant_id,
                client_id=config.client_id,
                client_secret=config.client_secret,
            ),
        ),
        (
            "local_file",
            lambda: LocalFileSigner(path=config.key_path),
        ),
    )
    for name, construct in registry:
        if provider == name:
            return construct()
    raise VaultSignerError(
        "factory",
        f"Unknown provider: '{provider}'. Supported: aws_kms, gcp_kms, hashicorp_vault, azure_key_vault, local_file",
    )
class AwsKmsSigner(VaultSigner):
    """Ed25519 signer whose private key is derived from an AWS KMS HMAC key.

    KMS never sees the payload. It computes HMAC-SHA-256 over a fixed
    *derivation message*, and the Ed25519 private key is
    ``SHA-512(mac)[:32]``. Signing then happens locally with
    ``cryptography``. The matching public key is read from the hex-encoded
    ``a1_verifying_key`` tag on the KMS key.

    Deriving the key from the payload itself would produce a different key
    for every message, and those signatures could never verify against the
    single published verifying key.

    Args:
        key_id: KMS key identifier passed as ``KeyId``. It must permit
            ``GenerateMac`` with ``HMAC_SHA_256``.
        region: AWS region for the KMS client.
        profile: Optional named profile for ``boto3.Session``.
        derivation_message: Fixed message that is MACed to derive the signing
            seed. Defaults to ``DEFAULT_DERIVATION_MESSAGE``.
            NOTE(review): the default is an assumption. It must equal the
            message ``a1 kms bootstrap`` used when it wrote the
            ``a1_verifying_key`` tag, so confirm it. A mismatch is detected in
            :meth:`sign` and raised rather than yielding unverifiable
            signatures.
    """

    # Default GenerateMac input for seed derivation (see NOTE above).
    DEFAULT_DERIVATION_MESSAGE: bytes = b"a1-ed25519-signing-seed-v1"

    def __init__(
        self,
        *,
        key_id: str,
        region: str,
        profile: Optional[str] = None,
        derivation_message: Optional[bytes] = None,
    ) -> None:
        self._key_id = key_id
        self._region = region
        self._profile = profile
        self._derivation_message = (
            self.DEFAULT_DERIVATION_MESSAGE
            if derivation_message is None
            else derivation_message
        )
        self._vk_cache: Optional[bytes] = None

    def _client(self) -> Any:
        """Create a KMS client for the configured profile and region."""
        try:
            import boto3
        except ImportError as exc:
            raise ImportError("boto3 is required: pip install boto3") from exc
        session = boto3.Session(profile_name=self._profile)
        return session.client("kms", region_name=self._region)

    def _derive_private_key(self) -> Any:
        """Return the Ed25519 private key derived from the KMS HMAC.

        Raises:
            ImportError: ``cryptography`` or ``boto3`` is not installed.
            VaultSignerError: The ``GenerateMac`` call failed.
        """
        # Check the local dependency before making a network call.
        try:
            from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
        except ImportError as exc:
            raise ImportError("cryptography is required: pip install cryptography") from exc
        kms = self._client()
        try:
            resp = kms.generate_mac(
                KeyId=self._key_id,
                Message=self._derivation_message,
                MacAlgorithm="HMAC_SHA_256",
            )
        except Exception as exc:
            raise VaultSignerError("aws_kms", str(exc)) from exc
        seed = hashlib.sha512(resp["Mac"]).digest()[:32]
        return Ed25519PrivateKey.from_private_bytes(seed)

    def sign(self, payload: bytes) -> bytes:
        """Sign ``payload`` with the KMS-derived Ed25519 key.

        Raises:
            ImportError: ``boto3`` or ``cryptography`` is not installed.
            VaultSignerError: A KMS call failed, or the derived public key
                does not match the ``a1_verifying_key`` tag.
        """
        private_key = self._derive_private_key()
        from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat

        derived_vk = private_key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw)
        # Refuse to emit signatures that could never verify against the
        # published key (e.g. a wrong derivation message or a stale tag).
        if derived_vk != self.verifying_key_bytes():
            raise VaultSignerError(
                "aws_kms",
                "Key derived via KMS GenerateMac does not match tag 'a1_verifying_key'. "
                "Check the derivation message or re-run "
                "`a1 kms bootstrap --provider aws --key-id <id>`.",
            )
        return private_key.sign(payload)

    def verifying_key_bytes(self) -> bytes:
        """Return the raw verifying key stored in the ``a1_verifying_key`` tag.

        KMS paginates tag listings, so every page is scanned. The result is
        cached on the instance.

        Raises:
            VaultSignerError: A KMS call failed, the tag is missing, or its
                value is not 32 bytes.
            ValueError: The tag value is not valid hex.
        """
        if self._vk_cache is not None:
            return self._vk_cache
        kms = self._client()
        request: Dict[str, Any] = {"KeyId": self._key_id}
        while True:
            try:
                resp = kms.list_resource_tags(**request)
            except Exception as exc:
                raise VaultSignerError("aws_kms", str(exc)) from exc
            for tag in resp.get("Tags", []):
                if tag.get("TagKey") == "a1_verifying_key":
                    vk = bytes.fromhex(tag["TagValue"])
                    if len(vk) != 32:
                        raise VaultSignerError(
                            "aws_kms",
                            f"Tag 'a1_verifying_key' holds {len(vk)} bytes; "
                            "expected a 32-byte Ed25519 public key.",
                        )
                    self._vk_cache = vk
                    return vk
            next_marker = resp.get("NextMarker")
            if not resp.get("Truncated") or not next_marker:
                break
            request["Marker"] = next_marker
        raise VaultSignerError(
            "aws_kms",
            "Tag 'a1_verifying_key' not found on KMS key. "
            "Run `a1 kms bootstrap --provider aws --key-id <id>` to initialize.",
        )
class GcpKmsSigner(VaultSigner):
    """Ed25519 signer backed by a Google Cloud KMS asymmetric signing key.

    Signing is delegated to Cloud KMS (``asymmetric_sign``). The public key of
    the configured key version is fetched once and cached on the instance.
    """

    def __init__(
        self,
        *,
        project_id: str,
        location_id: str,
        key_ring_id: str,
        crypto_key_id: str,
        key_version_id: str = "1",
    ) -> None:
        self._project_id = project_id
        self._location_id = location_id
        self._key_ring_id = key_ring_id
        self._crypto_key_id = crypto_key_id
        self._key_version_id = key_version_id
        self._vk_cache: Optional[bytes] = None

    def _key_version_name(self) -> str:
        """Return the fully qualified CryptoKeyVersion resource name."""
        return (
            f"projects/{self._project_id}/locations/{self._location_id}/"
            f"keyRings/{self._key_ring_id}/cryptoKeys/{self._crypto_key_id}/"
            f"cryptoKeyVersions/{self._key_version_id}"
        )

    def _client(self) -> Any:
        """Create a Cloud KMS client using ambient Google credentials."""
        try:
            from google.cloud import kms
        except ImportError as exc:
            raise ImportError(
                "google-cloud-kms is required: pip install google-cloud-kms"
            ) from exc
        return kms.KeyManagementServiceClient()

    def sign(self, payload: bytes) -> bytes:
        """Sign ``payload`` with the configured Ed25519 key version.

        Ed25519 hashes the message internally. Cloud KMS therefore takes the
        raw message in the request's ``data`` field for ``EC_SIGN_ED25519``
        keys; a pre-computed ``digest`` is not what the verifier checks
        against. NOTE(review): Cloud KMS limits the size of ``data``, so
        confirm that payloads stay within that limit.

        Raises:
            VaultSignerError: The Cloud KMS call failed.
        """
        client = self._client()
        try:
            resp = client.asymmetric_sign(
                request={"name": self._key_version_name(), "data": payload}
            )
        except Exception as exc:
            raise VaultSignerError("gcp_kms", str(exc)) from exc
        return bytes(resp.signature)

    def verifying_key_bytes(self) -> bytes:
        """Return the raw 32-byte Ed25519 public key of the key version.

        Raises:
            ImportError: ``cryptography`` is not installed.
            VaultSignerError: The Cloud KMS call failed or the key is not
                Ed25519.
        """
        if self._vk_cache is not None:
            return self._vk_cache
        # Check the local dependency before making a network call.
        try:
            from cryptography.hazmat.primitives.serialization import (
                Encoding,
                PublicFormat,
                load_pem_public_key,
            )
            from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
        except ImportError as exc:
            raise ImportError("cryptography is required: pip install cryptography") from exc
        client = self._client()
        try:
            resp = client.get_public_key({"name": self._key_version_name()})
        except Exception as exc:
            raise VaultSignerError("gcp_kms", str(exc)) from exc
        pub = load_pem_public_key(resp.pem.encode())
        if not isinstance(pub, Ed25519PublicKey):
            raise VaultSignerError("gcp_kms", "Key is not Ed25519. Use an Ed25519 SIGN_ASYMMETRIC key.")
        # public_bytes(Raw, Raw) works on all cryptography releases with
        # Ed25519 support; public_bytes_raw() needs a newer release.
        raw = pub.public_bytes(Encoding.Raw, PublicFormat.Raw)
        self._vk_cache = raw
        return raw
class HashiCorpVaultSigner(VaultSigner):
    """Signer backed by a HashiCorp Vault Transit key.

    Signing is delegated to Transit's sign endpoint. The public key of the
    key's ``latest_version`` is read once and cached on the instance.

    NOTE(review): :meth:`sign` does not pin a key version. After a Transit
    key rotation, new signatures come from the new version while a cached
    verifying key may still belong to the old one, so confirm how rotation
    is handled.
    """

    def __init__(
        self,
        *,
        vault_addr: str,
        key_name: str,
        token: Optional[str] = None,
        namespace: Optional[str] = None,
        ca_cert_path: Optional[str] = None,
    ) -> None:
        self._vault_addr = vault_addr
        self._key_name = key_name
        self._token = token
        self._namespace = namespace
        self._ca_cert_path = ca_cert_path
        self._vk_cache: Optional[bytes] = None

    def _client(self) -> Any:
        """Create an ``hvac`` client.

        A falsy token falls back to the ``VAULT_TOKEN`` environment variable.
        The CA bundle and namespace are passed only when they are set.
        """
        try:
            import hvac
        except ImportError as exc:
            raise ImportError("hvac is required: pip install hvac") from exc
        import os
        token = self._token or os.environ.get("VAULT_TOKEN", "")
        kwargs: Dict[str, Any] = {"url": self._vault_addr, "token": token}
        if self._ca_cert_path:
            kwargs["verify"] = self._ca_cert_path
        if self._namespace:
            kwargs["namespace"] = self._namespace
        return hvac.Client(**kwargs)

    def sign(self, payload: bytes) -> bytes:
        """Sign ``payload`` with the Transit key and return the raw signature.

        Raises:
            VaultSignerError: The Vault call failed or returned a malformed
                signature.
        """
        client = self._client()
        b64_input = base64.b64encode(payload).decode()
        try:
            # NOTE(review): signature_algorithm and marshaling_algorithm are
            # RSA- and ECDSA-specific options. Vault presumably ignores them
            # for ed25519 keys, but confirm before changing them.
            resp = client.secrets.transit.sign_data(
                name=self._key_name,
                hash_input=b64_input,
                hash_algorithm="sha2-256",
                signature_algorithm="pkcs1v15",
                prehashed=False,
                marshaling_algorithm="raw",
            )
        except Exception as exc:
            raise VaultSignerError("hashicorp_vault", str(exc)) from exc
        return self._decode_signature(resp["data"]["signature"])

    @staticmethod
    def _decode_signature(sig_str: str) -> bytes:
        """Strip Transit's ``vault:v<N>:`` prefix and base64-decode the rest.

        ``<N>`` is the key version and increases on rotation, so any version
        is accepted, not only ``v1``.
        """
        parts = sig_str.split(":", 2)
        if len(parts) != 3 or parts[0] != "vault" or not parts[1].startswith("v"):
            raise VaultSignerError(
                "hashicorp_vault",
                "Unexpected Transit signature format; expected 'vault:v<N>:<base64>'.",
            )
        try:
            return base64.b64decode(parts[2], validate=True)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise VaultSignerError(
                "hashicorp_vault", "Transit signature payload is not valid base64."
            ) from exc

    def verifying_key_bytes(self) -> bytes:
        """Return the raw Ed25519 public key of the latest Transit key version.

        Raises:
            VaultSignerError: The Vault call failed, no public key is present,
                or the key is not Ed25519.
            ImportError: A PEM key was returned and ``cryptography`` is
                missing.
        """
        if self._vk_cache is not None:
            return self._vk_cache
        client = self._client()
        try:
            resp = client.secrets.transit.read_key(name=self._key_name)
        except Exception as exc:
            raise VaultSignerError("hashicorp_vault", str(exc)) from exc
        data = resp["data"]
        keys = data.get("keys", {})
        latest = str(data.get("latest_version", 1))
        key_data = keys.get(latest, {})
        # Version entries for non-asymmetric keys may not be dicts (Vault can
        # report a plain value such as a timestamp), so guard before .get().
        pub_key_b64 = key_data.get("public_key", "") if isinstance(key_data, dict) else ""
        if not pub_key_b64:
            raise VaultSignerError(
                "hashicorp_vault",
                f"No public key found for key '{self._key_name}' version {latest}. "
                "Ensure the Transit key type is 'ed25519'.",
            )
        self._vk_cache = self._parse_public_key(pub_key_b64)
        return self._vk_cache

    @staticmethod
    def _parse_public_key(value: str) -> bytes:
        """Turn Transit's ``public_key`` field into raw Ed25519 key bytes.

        The field may be base64 of the raw 32-byte key or a PEM document.
        A PEM document is loaded as text, not base64-decoded first.
        """
        if value.lstrip().startswith("-----BEGIN"):
            pem = value.encode()
        else:
            raw = base64.b64decode(value)
            if len(raw) == 32:
                return raw
            # Fallback kept from the original behavior: base64-wrapped PEM.
            pem = raw
        try:
            from cryptography.hazmat.primitives.serialization import (
                Encoding,
                PublicFormat,
                load_pem_public_key,
            )
            from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
        except ImportError as exc:
            raise ImportError("cryptography is required: pip install cryptography") from exc
        pub = load_pem_public_key(pem)
        if not isinstance(pub, Ed25519PublicKey):
            raise VaultSignerError(
                "hashicorp_vault",
                "Transit public key is not Ed25519. Ensure the Transit key type is 'ed25519'.",
            )
        return pub.public_bytes(Encoding.Raw, PublicFormat.Raw)
class AzureKeyVaultSigner(VaultSigner):
    """Signer backed by a key in Azure Key Vault.

    Authenticates with ``azure.identity.ClientSecretCredential`` (a service
    principal) and expects an OKP / Ed25519 key.

    NOTE(review): confirm that the target vault and the installed
    ``azure-keyvault-keys`` release support Ed25519 keys and
    ``SignatureAlgorithm.ed25519``. If the SDK lacks that algorithm,
    :meth:`sign` raises :class:`VaultSignerError` with an explicit message.
    """

    def __init__(
        self,
        *,
        vault_name: str,
        key_name: str,
        tenant_id: str,
        client_id: str,
        client_secret: str,
        key_version: Optional[str] = None,
    ) -> None:
        self._vault_url = f"https://{vault_name}.vault.azure.net"
        self._key_name = key_name
        # None selects the latest key version.
        self._key_version = key_version
        self._tenant_id = tenant_id
        self._client_id = client_id
        self._client_secret = client_secret
        self._vk_cache: Optional[bytes] = None

    def _credential(self) -> Any:
        """Build the client-secret credential for the configured principal."""
        try:
            from azure.identity import ClientSecretCredential
        except ImportError as exc:
            raise ImportError(
                "Azure SDKs required: pip install azure-keyvault-keys azure-identity"
            ) from exc
        return ClientSecretCredential(
            tenant_id=self._tenant_id,
            client_id=self._client_id,
            client_secret=self._client_secret,
        )

    def _key_client(self, credential: Any = None) -> Any:
        """Return a ``KeyClient`` for the vault.

        Reuses ``credential`` when one is given; otherwise builds a new one.
        """
        try:
            from azure.keyvault.keys import KeyClient
        except ImportError as exc:
            raise ImportError(
                "Azure SDKs required: pip install azure-keyvault-keys azure-identity"
            ) from exc
        if credential is None:
            credential = self._credential()
        return KeyClient(vault_url=self._vault_url, credential=credential)

    def sign(self, payload: bytes) -> bytes:
        """Sign via Key Vault's ``CryptographyClient``.

        NOTE(review): the SHA-256 digest of ``payload`` is what gets signed.
        Ed25519 is normally applied to the raw message, so confirm that
        verifiers of these signatures hash the payload the same way.

        Raises:
            ImportError: The Azure SDKs are not installed.
            VaultSignerError: The SDK has no Ed25519 algorithm, or a Key Vault
                call failed.
        """
        try:
            from azure.keyvault.keys.crypto import CryptographyClient, SignatureAlgorithm
        except ImportError as exc:
            raise ImportError(
                "Azure SDKs required: pip install azure-keyvault-keys azure-identity"
            ) from exc
        # Resolve the algorithm up front so a missing enum member produces a
        # clear message instead of an opaque AttributeError text.
        algorithm = getattr(SignatureAlgorithm, "ed25519", None)
        if algorithm is None:
            raise VaultSignerError(
                "azure_key_vault",
                "The installed azure-keyvault-keys SDK has no SignatureAlgorithm.ed25519; "
                "Ed25519 signing is unavailable.",
            )
        # One credential serves both the KeyClient and the CryptographyClient.
        credential = self._credential()
        key_client = self._key_client(credential)
        digest = hashlib.sha256(payload).digest()
        try:
            key = key_client.get_key(self._key_name, version=self._key_version)
            crypto_client = CryptographyClient(key, credential=credential)
            result = crypto_client.sign(algorithm, digest)
        except Exception as exc:
            raise VaultSignerError("azure_key_vault", str(exc)) from exc
        return result.signature

    def verifying_key_bytes(self) -> bytes:
        """Return the raw 32-byte Ed25519 public key (the JWK ``x`` member).

        Raises:
            VaultSignerError: The Key Vault call failed, or the key is not an
                OKP/Ed25519 key.
        """
        if self._vk_cache is not None:
            return self._vk_cache
        key_client = self._key_client()
        try:
            key = key_client.get_key(self._key_name, version=self._key_version)
        except Exception as exc:
            raise VaultSignerError("azure_key_vault", str(exc)) from exc
        jwk = key.key
        kty = getattr(jwk, "kty", None)
        crv = getattr(jwk, "crv", None)
        x = getattr(jwk, "x", None)
        # EC keys (e.g. P-256) also carry an ``x`` coordinate of similar size,
        # so the key type and curve must be checked explicitly. Otherwise an
        # EC coordinate would be returned as if it were an Ed25519 key.
        is_ed25519 = (
            isinstance(kty, str)
            and kty.upper().startswith("OKP")
            and isinstance(crv, str)
            and crv.upper() == "ED25519"
        )
        if not is_ed25519 or x is None or len(x) != 32:
            raise VaultSignerError(
                "azure_key_vault",
                f"Key '{self._key_name}' has no Ed25519 public key material. "
                "Ensure the key type is OKP (EdDSA / Ed25519).",
            )
        self._vk_cache = bytes(x)
        return self._vk_cache
class LocalFileSigner(VaultSigner):
    """Ed25519 signer using a private key stored in a local file.

    The file may hold either a PEM-encoded private key (decrypted with
    ``password`` if one is given) or the raw private key as hex. The file is
    read lazily on first use, and the loaded key is kept on the instance.
    """

    def __init__(
        self,
        *,
        path: str,
        password: Optional[bytes] = None,
    ) -> None:
        self._path = path
        self._password = password
        self._private_key: Any = None
        self._vk_cache: Optional[bytes] = None

    def _ensure_loaded(self) -> None:
        """Load and validate the private key from ``self._path`` once.

        Raises:
            ImportError: ``cryptography`` is not installed.
            OSError: The key file cannot be read.
            ValueError: The PEM or hex content cannot be parsed.
            VaultSignerError: The file holds a key that is not Ed25519.
        """
        if self._private_key is not None:
            return
        try:
            from cryptography.hazmat.primitives.serialization import load_pem_private_key
            from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey
        except ImportError as exc:
            raise ImportError("cryptography is required: pip install cryptography") from exc
        with open(self._path, "rb") as f:
            raw = f.read()
        stripped = raw.strip()
        if stripped.startswith(b"-----"):
            key = load_pem_private_key(raw, password=self._password)
        else:
            # Hex form; surrounding whitespace and newlines are ignored.
            key = Ed25519PrivateKey.from_private_bytes(bytes.fromhex(stripped.decode()))
        # A PEM file can carry any algorithm (RSA, EC, ...). Reject other key
        # types here rather than failing later with unrelated errors in
        # sign() or verifying_key_bytes().
        if not isinstance(key, Ed25519PrivateKey):
            raise VaultSignerError(
                "local_file",
                f"Key in '{self._path}' is not an Ed25519 private key "
                f"(got {type(key).__name__}).",
            )
        self._private_key = key

    def sign(self, payload: bytes) -> bytes:
        """Sign ``payload`` with the loaded Ed25519 key.

        Raises:
            VaultSignerError: Signing raised an exception.
        """
        self._ensure_loaded()
        try:
            return self._private_key.sign(payload)
        except Exception as exc:
            raise VaultSignerError("local_file", str(exc)) from exc

    def verifying_key_bytes(self) -> bytes:
        """Return the raw 32-byte Ed25519 public key, cached after first use."""
        if self._vk_cache is not None:
            return self._vk_cache
        self._ensure_loaded()
        pub = self._private_key.public_key()
        try:
            raw = pub.public_bytes_raw()
        except AttributeError:
            # Older cryptography releases lack public_bytes_raw().
            from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
            raw = pub.public_bytes(Encoding.Raw, PublicFormat.Raw)
        self._vk_cache = raw
        return raw