def sample_bernoulli_float(prob: T, constant_time: bool) -> bool:
    """Draw True with probability exactly ``prob``, using prob's bit pattern.

    Write ``prob`` as a binary fraction ``0.b_0 b_1 b_2 ...``. Draw an index
    ``i`` with ``P(i) = 2**-(i + 1)``, i.e. the position of the first heads
    in a run of fair coin flips. Then return digit ``b_i``. Summing over
    ``i`` gives ``P(True) == prob``. Each digit is read directly from the
    float's exponent and mantissa fields, so no rounding is involved.

    NOTE(review): this relies on several assumptions; confirm each against
    its definition.
      * ``T`` is an IEEE-754 binary float type exposing ``EXPONENT_BIAS``
        and ``MANTISSA_BITS``.
      * ``prob.raw_exponent()`` returns the biased exponent field.
      * ``prob.to_bits()`` returns the raw bit pattern.
      * ``sample_geometric_buffer(n_bytes, constant_time)`` returns the
        0-based index of the first set bit among ``8 * n_bytes`` uniformly
        random bits, or ``None`` if no bit is set.

    Args:
        prob: Success probability. Must lie in [0, 1].
        constant_time: Forwarded unchanged to ``sample_geometric_buffer``.

    Returns:
        True with probability ``prob``, otherwise False.

    Raises:
        ValueError: If ``prob`` is outside [0, 1], including NaN.
    """
    # NaN fails both comparisons. For any out-of-range value, the exponent
    # arithmetic below would produce meaningless (even negative) digit
    # indices, so fail loudly instead of returning a biased sample.
    if not 0 <= prob <= 1:
        raise ValueError(f"probability must be in [0, 1], got {prob!r}")
    # 1.0 has no fractional digits: its leading bit sits left of the binary
    # point. The digit lookup below therefore cannot represent it.
    if prob == 1:
        return True

    # Upper bound on any digit index that can be nonzero. The deepest one is
    # (EXPONENT_BIAS - 2) + MANTISSA_BITS, computed below. inf_div presumably
    # rounds up to whole bytes of random bits.
    max_coin_flips = usize.exact_int_cast(T.EXPONENT_BIAS) + usize.exact_int_cast(
        T.MANTISSA_BITS
    )
    buffer_len = max_coin_flips.inf_div(8)

    first_heads_index = sample_geometric_buffer(buffer_len, constant_time)
    # No heads anywhere in the buffer. Every digit that deep is zero.
    if first_heads_index is None:
        return False

    raw_exponent = prob.raw_exponent()
    # Count of zero digits after the binary point, before the (implicit)
    # leading bit.
    # Subnormals (raw exponent 0) are scaled like raw exponent 1: their value
    # is 0.mantissa * 2**(1 - bias). Using 0 here would push every mantissa
    # digit one place too deep and halve their probability.
    leading_zeros = T.EXPONENT_BIAS - 1 - max(raw_exponent, 1)

    if first_heads_index < leading_zeros:
        return False
    if first_heads_index == leading_zeros:
        # Implicit leading bit: 1 for normal numbers, 0 for subnormals/zero.
        return raw_exponent != 0
    if first_heads_index > leading_zeros + T.MANTISSA_BITS:
        # Past the last stored mantissa bit, every digit is zero.
        return False
    # The k-th digit after the implicit bit (k = 1..MANTISSA_BITS) is stored
    # at bit position MANTISSA_BITS - k. The mask never reaches the exponent
    # or sign fields.
    mask = 1 << (T.MANTISSA_BITS + leading_zeros - first_heads_index)
    return (prob.to_bits() & mask) != 0