from __future__ import annotations
__version__ = '0.4.1'
__author__ = 'Daniel Kahn Gillmor <dkg@fifthhorseman.net>'
__all__ = [
'ML_KEM_512',
'ML_KEM_768',
'ML_KEM_1024',
'Ciphertext',
'EncapsulationKey',
'DecapsulationKey',
'Seed',
]
import ctypes
import ctypes.util
import enum
import secrets
from typing import Tuple, Dict, Any, Union, Optional
from abc import ABC
import sys
class _SharedSecret(ctypes.Structure):
_fields_ = [('data', ctypes.c_uint8 * 32)]
class _Seed(ctypes.Structure):
_fields_ = [('data', ctypes.c_uint8 * 64)]
class Err(enum.IntEnum):
OK = 0
NULL_PTR_ERROR = 1
SERIALIZATION_ERROR = 2
DESERIALIZATION_ERROR = 3
KEYGEN_ERROR = 4
ENCAPSULATION_ERROR = 5
DECAPSULATION_ERROR = 6
class Seed():
def __init__(self, data: Optional[bytes] = None) -> None:
self._seed = _Seed()
if data is None:
data = secrets.token_bytes(len(self._seed.data))
if len(data) != len(self._seed.data):
raise ValueError(f"Expected {len(self._seed.data)} bytes, "
f"got {len(data)}.")
for i in range(len(data)):
self._seed.data[i] = data[i]
def __repr__(self) -> str:
return '<ML-KEM Seed>'
def __bytes__(self) -> bytes:
return bytes(self._seed.data)
def keygen(self, strength: int) -> Tuple[EncapsulationKey, DecapsulationKey]:
for kt in ML_KEM_512, ML_KEM_768, ML_KEM_1024:
if kt._strength == strength:
return kt.keygen(self)
raise Exception(f"Unknown strength: {strength}, must be 512, 768, or 1024.")
class Ciphertext():
def __init__(self, data: Union[bytes, int]) -> None:
if isinstance(data, bytes):
self._strength = _ML_KEM.strength_from_length('CT_SIZE', len(data))
elif isinstance(data, int):
self._strength = data
else:
raise Exception("Initialize ML-KEM Ciphertext object with "
f"bytes or a strength level, not {type(data)}")
self._ffi = _ML_KEM.strength(self._strength)
self._ct = self._ffi['Ciphertext']()
if isinstance(data, bytes):
self._set(data)
def __repr__(self) -> str:
return f'<ML-KEM-{self._strength} Ciphertext>'
def __bytes__(self) -> bytes:
return bytes(self._ct.data)
def _set(self, data: bytes) -> None:
if len(data) != len(self._ct.data):
raise ValueError(f"Expected {len(self._ct.data)} bytes, "
f"got {len(data)}")
for i in range(len(data)):
self._ct.data[i] = data[i]
class EncapsulationKey():
def __init__(self, data: Union[bytes, int]) -> None:
if isinstance(data, bytes):
self._strength = _ML_KEM.strength_from_length('EK_SIZE', len(data))
elif isinstance(data, int):
self._strength = data
else:
raise Exception("Initialize ML-KEM Encapsulation Key with "
f"bytes or a strength level, not {type(data)}")
self._ffi = _ML_KEM.strength(self._strength)
self._ek = self._ffi['EncapsKey']()
if isinstance(data, bytes):
self._set(data)
def __repr__(self) -> str:
return f'<ML-KEM-{self._strength} Encapsulation Key>'
def __bytes__(self) -> bytes:
return bytes(self._ek.data)
def _set(self, data: bytes) -> None:
if len(data) != len(self._ek.data):
raise ValueError(f"Expected {len(self._ek.data)} bytes, "
f"got {len(data)}")
for i in range(len(data)):
self._ek.data[i] = data[i]
def encaps(self) -> Tuple[Ciphertext, bytes]:
ct = Ciphertext(self._strength)
ss = _SharedSecret()
ret = Err(self._ffi['encaps'](ctypes.byref(self._ek),
ctypes.byref(ct._ct),
ctypes.byref(ss)))
if ret is not Err.OK:
raise Exception(f"ml_kem_{self._strength}_encaps() "
f"returned {ret} ({ret.name})")
return (ct, bytes(ss.data))
class DecapsulationKey():
def __init__(self, data: Union[bytes, int]) -> None:
if isinstance(data, bytes):
self._strength = _ML_KEM.strength_from_length('DK_SIZE', len(data))
elif isinstance(data, int):
self._strength = data
else:
raise Exception("Initialize ML-KEM Encapsulation Key with bytes "
f"or a strength level, not {type(data)}")
self._ffi = _ML_KEM.strength(self._strength)
self._dk = self._ffi['DecapsKey']()
if isinstance(data, bytes):
self._set(data)
def __repr__(self) -> str:
return f'<ML-KEM-{self._strength} Decapsulation Key>'
def __bytes__(self) -> bytes:
return bytes(self._dk.data)
def _set(self, data: bytes) -> None:
if len(data) != len(self._dk.data):
raise ValueError(f"Expected {len(self._dk.data)} bytes, "
f"got {len(data)}")
for i in range(len(data)):
self._dk.data[i] = data[i]
def decaps(self, ct: Ciphertext) -> bytes:
if self._strength != ct._strength:
raise Exception(f"Cannot decapsulate {ct} with {self}")
ss = _SharedSecret()
ret = Err(self._ffi['decaps'](ctypes.byref(self._dk),
ctypes.byref(ct._ct),
ctypes.byref(ss)))
if ret is not Err.OK:
raise Exception(f"ml_kem_{self._strength}_decaps() "
f"returned {ret} ({ret.name})")
return bytes(ss.data)
class _ML_KEM():
params: Dict[int, Dict[str, int]] = {
512: {
'EK_SIZE': 800,
'DK_SIZE': 1632,
'CT_SIZE': 768,
},
768: {
'EK_SIZE': 1184,
'DK_SIZE': 2400,
'CT_SIZE': 1088,
},
1024: {
'EK_SIZE': 1568,
'DK_SIZE': 3168,
'CT_SIZE': 1568,
},
}
lib = ctypes.CDLL(ctypes.util.find_library('fips203'))
if not hasattr(lib, 'ml_kem_512_keygen'):
if sys.platform == 'darwin':
lib = ctypes.CDLL('../../target/debug/libfips203.dylib')
else:
lib = ctypes.CDLL("../../target/debug/libfips203.so")
ffi: Dict[int, Dict[str, Any]] = {}
@classmethod
def strength(cls, level: int) -> Dict[str, Any]:
if level not in cls.ffi:
class _EncapsKey(ctypes.Structure):
_fields_ = [('data', ctypes.c_uint8 *
cls.params[level]['EK_SIZE'])]
class _DecapsKey(ctypes.Structure):
_fields_ = [('data', ctypes.c_uint8 *
cls.params[level]['DK_SIZE'])]
class _Ciphertext(ctypes.Structure):
_fields_ = [('data', ctypes.c_uint8 *
cls.params[level]['CT_SIZE'])]
ffi: Dict[str, Any] = {}
ffi['keygen'] = cls.lib[f'ml_kem_{level}_keygen']
ffi['keygen'].argtypes = [ctypes.POINTER(_EncapsKey),
ctypes.POINTER(_DecapsKey)]
ffi['keygen'].restype = ctypes.c_uint8
ffi['keygen_from_seed'] = cls.lib[f'ml_kem_{level}_keygen_from_seed']
ffi['keygen_from_seed'].argtypes = [ctypes.POINTER(_Seed),
ctypes.POINTER(_EncapsKey),
ctypes.POINTER(_DecapsKey)]
ffi['keygen_from_seed'].restype = ctypes.c_uint8
ffi['encaps'] = cls.lib[f'ml_kem_{level}_encaps']
ffi['encaps'].argtypes = [ctypes.POINTER(_EncapsKey),
ctypes.POINTER(_Ciphertext),
ctypes.POINTER(_SharedSecret)]
ffi['encaps'].restype = ctypes.c_uint8
ffi['decaps'] = cls.lib[f'ml_kem_{level}_decaps']
ffi['decaps'].argtypes = [ctypes.POINTER(_DecapsKey),
ctypes.POINTER(_Ciphertext),
ctypes.POINTER(_SharedSecret)]
ffi['decaps'].restype = ctypes.c_uint8
ffi['EncapsKey'] = _EncapsKey
ffi['DecapsKey'] = _DecapsKey
ffi['Ciphertext'] = _Ciphertext
cls.ffi[level] = ffi
return cls.ffi[level]
@classmethod
def strength_from_length(cls, object_type: str, object_len: int) -> int:
for strength in cls.params:
if cls.params[strength][object_type] == object_len:
return strength
raise Exception(f"No ML-KEM parameter set has {object_type} "
f"of {object_len} bytes")
@classmethod
def _keygen(cls, strength: int) -> Tuple[EncapsulationKey,
DecapsulationKey]:
ek = EncapsulationKey(strength)
dk = DecapsulationKey(strength)
ret = Err(cls.strength(strength)['keygen'](ctypes.byref(ek._ek),
ctypes.byref(dk._dk)))
if ret is not Err.OK:
raise Exception(f"ml_kem_{strength}_keygen() returned "
f"{ret} ({ret.name})")
return (ek, dk)
@classmethod
def _keygen_from_seed(cls, strength: int, seed: Seed) -> Tuple[EncapsulationKey,
DecapsulationKey]:
ek = EncapsulationKey(strength)
dk = DecapsulationKey(strength)
ret = Err(cls.strength(strength)['keygen_from_seed'](
ctypes.byref(seed._seed),
ctypes.byref(ek._ek),
ctypes.byref(dk._dk)
))
if ret is not Err.OK:
raise Exception(f"ml_kem_{strength}_keygen() returned "
f"{ret} ({ret.name})")
return (ek, dk)
class ML_KEM(ABC):
_strength: int
EK_SIZE: int
DK_SIZE: int
CT_SIZE: int
SS_SIZE: int = 32
SEED_SIZE: int = 64
@classmethod
def keygen(cls, seed: Optional[Seed] = None) -> Tuple[EncapsulationKey, DecapsulationKey]:
if seed is None:
return _ML_KEM._keygen(cls._strength)
else:
return _ML_KEM._keygen_from_seed(cls._strength, seed)
class ML_KEM_512(ML_KEM):
_strength: int = 512
EK_SIZE: int = 800
DK_SIZE: int = 1632
CT_SIZE: int = 768
class ML_KEM_768(ML_KEM):
_strength: int = 768
EK_SIZE: int = 1184
DK_SIZE: int = 2400
CT_SIZE: int = 1088
class ML_KEM_1024(ML_KEM):
_strength: int = 1024
EK_SIZE: int = 1568
DK_SIZE: int = 3168
CT_SIZE: int = 1568