pub mod config_contracts {
pub fn validate_size(size: usize, max: usize) -> bool {
size <= max
}
pub fn validate_index(index: usize, len: usize) -> bool {
index < len
}
pub fn validated_len(data: &[u8]) -> usize {
debug_assert!(!data.is_empty(), "data must not be empty");
data.len()
}
}
pub mod numeric_contracts {
pub fn checked_add(a: u64, b: u64) -> Option<u64> {
a.checked_add(b)
}
pub fn is_valid_float(val: f64) -> bool {
val.is_finite()
}
pub fn normalize(val: f64, min: f64, max: f64) -> f64 {
debug_assert!(max > min, "max must be greater than min");
((val - min) / (max - min)).clamp(0.0, 1.0)
}
}
#[cfg(verus)]
mod verus_specs {
use builtin::*;
use builtin_macros::*;
verus! {
#[requires(seq_len >= 0)]
#[ensures(result == (seq_len < max_seq_len))]
#[recommends(max_seq_len >= 2048)]
fn verify_kv_cache_capacity(seq_len: u64, max_seq_len: u64) -> bool {
seq_len < max_seq_len
}
#[requires(num_layers > 0 && num_heads > 0 && head_dim > 0)]
#[ensures(result == num_layers * num_heads * head_dim * max_seq_len * 2)]
fn verify_kv_cache_size(
num_layers: u64, num_heads: u64, head_dim: u64, max_seq_len: u64,
) -> u64 {
num_layers * num_heads * head_dim * max_seq_len * 2
}
#[requires(pos >= 0)]
#[ensures(result == pos + 1)]
#[invariant(pos < max_seq_len)]
fn verify_kv_cache_advance(pos: u64, max_seq_len: u64) -> u64 {
pos + 1
}
#[requires(generated >= 0)]
#[ensures(result == (generated < max_tokens))]
#[recommends(max_tokens <= 4096)]
fn verify_token_budget(generated: u64, max_tokens: u64) -> bool {
generated < max_tokens
}
#[requires(token_id >= 0)]
#[ensures(result == (token_id < vocab_size))]
fn verify_token_id(token_id: u64, vocab_size: u64) -> bool {
token_id < vocab_size
}
#[requires(prompt_len > 0)]
#[ensures(result == prompt_len + max_new_tokens)]
#[recommends(prompt_len + max_new_tokens <= 131072)]
fn verify_total_sequence_len(prompt_len: u64, max_new_tokens: u64) -> u64 {
prompt_len + max_new_tokens
}
#[requires(seq_len > 0 && head_dim > 0)]
#[ensures(result == seq_len * head_dim)]
fn verify_attention_buffer_size(seq_len: u64, head_dim: u64) -> u64 {
seq_len * head_dim
}
#[requires(num_q_heads > 0 && num_kv_heads > 0)]
#[ensures(result == num_q_heads / num_kv_heads)]
#[invariant(num_q_heads % num_kv_heads == 0)]
fn verify_gqa_ratio(num_q_heads: u64, num_kv_heads: u64) -> u64 {
num_q_heads / num_kv_heads
}
#[requires(head_dim > 0)]
#[ensures(result > 0)]
fn verify_attention_scale(head_dim: u64) -> u64 {
head_dim }
#[requires(logits_len > 0)]
#[ensures(result < logits_len)]
fn verify_sampled_token(result_idx: u64, logits_len: u64) -> u64 { result_idx }
#[requires(top_k > 0)]
#[ensures(result <= vocab_size)]
#[recommends(top_k <= 100)]
fn verify_top_k_bounds(top_k: u64, vocab_size: u64) -> u64 {
if top_k > vocab_size { vocab_size } else { top_k }
}
#[requires(top_p > 0)]
#[ensures(result == (top_p <= 100))]
#[recommends(top_p >= 80 && top_p <= 100)]
fn verify_top_p_bounds(top_p: u64) -> bool {
top_p <= 100
}
#[requires(num_tensors > 0)]
#[ensures(result == (loaded == num_tensors))]
#[invariant(loaded <= num_tensors)]
fn verify_model_load_complete(loaded: u64, num_tensors: u64) -> bool {
loaded == num_tensors
}
#[requires(expected_size > 0)]
#[ensures(result == (actual_size == expected_size))]
fn verify_tensor_size_match(actual_size: u64, expected_size: u64) -> bool {
actual_size == expected_size
}
#[requires(elapsed_ms > 0)]
#[ensures(result == (tokens_generated * 1000) / elapsed_ms)]
#[recommends(result >= 30)]
fn verify_tokens_per_second(tokens_generated: u64, elapsed_ms: u64) -> u64 {
(tokens_generated * 1000) / elapsed_ms
}
#[requires(pos >= 0)]
#[ensures(result == (pos < max_pos))]
#[invariant(max_pos > 0)]
fn verify_rope_position(pos: u64, max_pos: u64) -> bool {
pos < max_pos
}
#[requires(dim > 0 && dim % 2 == 0)]
#[ensures(result == dim / 2)]
#[decreases(dim)]
fn verify_rope_pairs(dim: u64) -> u64 {
dim / 2
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_size() {
assert!(config_contracts::validate_size(5, 10));
assert!(!config_contracts::validate_size(11, 10));
assert!(config_contracts::validate_size(10, 10));
}
#[test]
fn test_validate_index() {
assert!(config_contracts::validate_index(0, 5));
assert!(config_contracts::validate_index(4, 5));
assert!(!config_contracts::validate_index(5, 5));
}
#[test]
fn test_validated_len() {
assert_eq!(config_contracts::validated_len(&[1, 2, 3]), 3);
}
#[test]
fn test_checked_add() {
assert_eq!(numeric_contracts::checked_add(1, 2), Some(3));
assert_eq!(numeric_contracts::checked_add(u64::MAX, 1), None);
}
#[test]
fn test_is_valid_float() {
assert!(numeric_contracts::is_valid_float(1.0));
assert!(!numeric_contracts::is_valid_float(f64::NAN));
assert!(!numeric_contracts::is_valid_float(f64::INFINITY));
}
#[test]
fn test_normalize() {
let result = numeric_contracts::normalize(5.0, 0.0, 10.0);
assert!((result - 0.5).abs() < f64::EPSILON);
assert!((numeric_contracts::normalize(0.0, 0.0, 10.0)).abs() < f64::EPSILON);
assert!((numeric_contracts::normalize(10.0, 0.0, 10.0) - 1.0).abs() < f64::EPSILON);
}
}
#[cfg(kani)]
mod kani_proofs {
#[kani::proof]
fn verify_config_bounds() {
let val: u32 = kani::any();
kani::assume(val <= 1000);
assert!(val <= 1000);
}
#[kani::proof]
fn verify_index_safety() {
let len: usize = kani::any();
kani::assume(len > 0 && len <= 1024);
let idx: usize = kani::any();
kani::assume(idx < len);
assert!(idx < len);
}
#[kani::proof]
fn verify_no_overflow_add() {
let a: u32 = kani::any();
let b: u32 = kani::any();
kani::assume(a <= 10000);
kani::assume(b <= 10000);
let result = a.checked_add(b);
assert!(result.is_some());
}
#[kani::proof]
fn verify_no_overflow_mul() {
let a: u32 = kani::any();
let b: u32 = kani::any();
kani::assume(a <= 1000);
kani::assume(b <= 1000);
let result = a.checked_mul(b);
assert!(result.is_some());
}
#[kani::proof]
fn verify_division_nonzero() {
let numerator: u64 = kani::any();
let denominator: u64 = kani::any();
kani::assume(denominator > 0);
let result = numerator / denominator;
assert!(result <= numerator);
}
}