use crate::traits::FloatExt;
/// Equilibrium coefficients for one trading round, as returned by [`single_period_kyle`]
/// and (one entry per round) by [`multi_period_kyle`].
///
/// The field names follow the usual notation of Kyle-type insider-trading models. The
/// economic readings given below are the conventional ones for that notation; nothing in
/// this type enforces them.
#[derive(Debug, Clone, Copy)]
pub struct KyleEquilibrium<T: FloatExt> {
/// Coefficient `β` (conventionally the informed trader's trading intensity).
pub beta: T,
/// Coefficient `λ` (conventionally the price-impact slope of the pricing rule).
pub lambda: T,
/// Variance left after the round. The multi-period solver carries this value into the
/// next round as that round's starting variance.
pub posterior_variance: T,
/// Expected profit attributed to this round: `½·√(Σ₀·σ_u²)` in the single-period solver,
/// `β·Σ·(1 − λ·β)` (with `Σ` the variance entering the round) in the multi-period one.
pub expected_profit: T,
}
/// Closed-form equilibrium of the single-period Kyle model.
///
/// Writing `Σ₀` for `prior_variance` and `σ_u²` for `noise_variance`, the returned
/// coefficients are
///
/// * `beta = σ_u / √Σ₀`
/// * `lambda = ½ · √Σ₀ / σ_u`
/// * `posterior_variance = ½ · Σ₀`
/// * `expected_profit = ½ · √Σ₀ · σ_u`
///
/// so that `beta · lambda = ½` up to rounding.
///
/// # Panics
///
/// Panics unless `prior_variance > 0` and `noise_variance > 0`. A value that does not
/// compare greater than zero (for the primitive floats this includes NaN) is rejected too.
pub fn single_period_kyle<T: FloatExt>(prior_variance: T, noise_variance: T) -> KyleEquilibrium<T> {
    assert!(
        prior_variance > T::zero(),
        "prior_variance must be positive"
    );
    assert!(
        noise_variance > T::zero(),
        "noise_variance must be positive"
    );

    let half = T::from_f64_fast(0.5);

    // Take the square roots *before* combining the two variances. Forming
    // `noise / prior` or `prior * noise` first can overflow to infinity or underflow to
    // zero for extreme inputs even though the final coefficient is representable; the
    // standard deviations have half the exponent range and combine safely.
    let value_sd = prior_variance.sqrt();
    let noise_sd = noise_variance.sqrt();

    KyleEquilibrium {
        beta: noise_sd / value_sd,
        lambda: half * (value_sd / noise_sd),
        posterior_variance: half * prior_variance,
        expected_profit: half * (value_sd * noise_sd),
    }
}
/// Round-by-round equilibrium of an `n_periods`-round Kyle model.
///
/// The auxiliary values `γ_n` (one per round, zero for the last round) are obtained from
/// `backward_gamma_sequence`. The rounds are then walked forwards starting from
/// `Σ = prior_variance`, with `σ_u² = noise_variance_per_round`:
///
/// * `λ_n = √( (1 − 2γ_n)·Σ / (4·(1 − γ_n)²·σ_u²) )`
/// * `β_n = (1 − 2γ_n) / (2·λ_n·(1 − γ_n))`
/// * `Σ_n = Σ / (2·(1 − γ_n))`, reported as `posterior_variance` and used as `Σ` for the
///   following round
/// * `expected_profit = β_n·Σ·(1 − λ_n·β_n)`
///
/// The result holds one entry per round, first round first.
///
/// # Panics
///
/// Panics unless `prior_variance > 0`, `noise_variance_per_round > 0` and
/// `n_periods >= 1`.
pub fn multi_period_kyle<T: FloatExt>(
    prior_variance: T,
    noise_variance_per_round: T,
    n_periods: usize,
) -> Vec<KyleEquilibrium<T>> {
    assert!(
        prior_variance > T::zero(),
        "prior_variance must be positive"
    );
    assert!(
        noise_variance_per_round > T::zero(),
        "noise_variance_per_round must be positive"
    );
    assert!(n_periods >= 1, "n_periods must be at least 1");

    let one = T::one();
    let two = T::from_f64_fast(2.0);
    let four = T::from_f64_fast(4.0);

    // Variance still unresolved at the start of the round being processed.
    let mut remaining_variance = prior_variance;

    backward_gamma_sequence::<T>(n_periods)
        .into_iter()
        .map(|gamma| {
            let entering = remaining_variance;
            let comp = one - gamma;
            let comp_twice = one - two * gamma;

            let lambda =
                (comp_twice * entering / (four * comp * comp * noise_variance_per_round)).sqrt();
            let beta = comp_twice / (two * lambda * comp);
            let posterior_variance = entering / (two * comp);
            let expected_profit = beta * entering * (one - lambda * beta);

            // The next round starts from what this round left unresolved.
            remaining_variance = posterior_variance;

            KyleEquilibrium {
                beta,
                lambda,
                posterior_variance,
                expected_profit,
            }
        })
        .collect()
}
/// Solves `8·(1 − 2·g_next)·g²·(1 − g) = 1 − 2·g` for `g` by bisection.
///
/// The bracket starts as `(1e-15, 0.5 − 1e-15)` and is halved a fixed 120 times; the
/// midpoint of the final bracket is returned. At the lower edge the residual is close to
/// `−1`, and for `g_next < 0.5` it is positive at the upper edge, so the bracket contains
/// a sign change.
///
/// NOTE(review): `g_next` is expected to lie in `[0, 0.5)`; this is not checked here.
fn solve_gamma_cubic<T: FloatExt>(g_next: T) -> T {
    let one = T::one();
    let half = T::from_f64_fast(0.5);
    let two = T::from_f64_fast(2.0);
    let edge = T::from_f64_fast(1e-15);

    let scale = T::from_f64_fast(8.0) * (one - two * g_next);
    // Left-hand side minus right-hand side of the equation above.
    let residual = |g: T| scale * g * g * (one - g) - (one - two * g);

    let (lower, upper) = (0..120).fold((edge, half - edge), |(lower, upper), _| {
        let midpoint = (lower + upper) * half;
        if residual(midpoint) < T::zero() {
            // Still left of the root: raise the lower edge.
            (midpoint, upper)
        } else {
            (lower, midpoint)
        }
    });

    (lower + upper) * half
}
/// Builds the per-round `γ` values for an `n_periods`-round problem.
///
/// The last round's value is zero; each earlier value is obtained by feeding its
/// successor to `solve_gamma_cubic`. The returned vector is ordered first round first and
/// is empty when `n_periods` is zero.
fn backward_gamma_sequence<T: FloatExt>(n_periods: usize) -> Vec<T> {
    // Generate from the terminal round backwards, then flip into chronological order.
    let mut reversed = Vec::with_capacity(n_periods);
    let mut gamma = T::zero();
    for step in 0..n_periods {
        if step > 0 {
            gamma = solve_gamma_cubic(gamma);
        }
        reversed.push(gamma);
    }
    reversed.reverse();
    reversed
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` differ by at most `tol` in absolute terms.
    fn within(a: f64, b: f64, tol: f64) -> bool {
        let gap = a - b;
        gap.abs() <= tol
    }

    /// The single-period coefficients multiply to one half.
    #[test]
    fn single_period_satisfies_beta_lambda_half() {
        let KyleEquilibrium { beta, lambda, .. } = single_period_kyle(0.04_f64, 1.0);
        assert!(within(beta * lambda, 0.5, 1e-12));
    }

    /// Unit variances give lambda = 1/2, beta = 1 and profit = 1/2.
    #[test]
    fn single_period_lambda_scales_correctly() {
        let unit = single_period_kyle(1.0_f64, 1.0);
        let checks = [
            (unit.lambda, 0.5),
            (unit.beta, 1.0),
            (unit.expected_profit, 0.5),
        ];
        for &(actual, expected) in checks.iter() {
            assert!(within(actual, expected, 1e-12));
        }
    }

    /// The posterior variance is half the prior, whatever the noise variance.
    #[test]
    fn single_period_posterior_halves_prior() {
        let posterior = single_period_kyle(2.5_f64, 0.1).posterior_variance;
        assert!(within(posterior, 1.25, 1e-12));
    }

    /// One entry per round, each with strictly positive coefficients.
    #[test]
    fn multi_period_returns_one_per_round() {
        let rounds = multi_period_kyle(1.0_f64, 1.0, 5);
        assert_eq!(rounds.len(), 5);
        rounds.iter().for_each(|round| {
            assert!(round.lambda > 0.0);
            assert!(round.beta > 0.0);
            assert!(round.posterior_variance > 0.0);
        });
    }

    /// The posterior variance falls strictly from round to round.
    #[test]
    fn multi_period_posterior_decreases() {
        let rounds = multi_period_kyle(1.0_f64, 1.0, 8);
        let mut ceiling = f64::INFINITY;
        for variance in rounds.iter().map(|round| round.posterior_variance) {
            assert!(variance < ceiling);
            ceiling = variance;
        }
    }

    /// A one-round multi-period solve must reproduce the closed-form single-period result.
    #[test]
    fn multi_period_one_round_matches_single_period() {
        let (prior, noise) = (0.04_f64, 1.0);
        let single = single_period_kyle(prior, noise);
        let multi = multi_period_kyle(prior, noise, 1);
        assert_eq!(multi.len(), 1);
        let only = multi[0];

        assert!(
            within(only.lambda, single.lambda, 1e-12),
            "lambda mismatch: multi={} single={}",
            only.lambda,
            single.lambda
        );
        assert!(
            within(only.beta, single.beta, 1e-12),
            "beta mismatch: multi={} single={}",
            only.beta,
            single.beta
        );
        assert!(
            within(only.posterior_variance, single.posterior_variance, 1e-12),
            "posterior_variance mismatch: multi={} single={}",
            only.posterior_variance,
            single.posterior_variance
        );
        assert!(
            within(only.expected_profit, single.expected_profit, 1e-12),
            "expected_profit mismatch: multi={} single={}",
            only.expected_profit,
            single.expected_profit
        );
    }

    /// Two-round solve with unit variances against reference values (3 decimals).
    #[test]
    fn multi_period_two_round_matches_canonical() {
        let rounds = multi_period_kyle(1.0_f64, 1.0, 2);
        assert_eq!(rounds.len(), 2);
        let (first, last) = (rounds[0], rounds[1]);

        assert!(
            within(first.lambda, 0.4617, 1e-3),
            "λ_1 = {}",
            first.lambda
        );
        assert!(within(first.beta, 0.6669, 1e-3), "β_1 = {}", first.beta);
        assert!(
            within(first.posterior_variance, 0.6920, 1e-3),
            "Σ_1 = {}",
            first.posterior_variance
        );

        assert!(within(last.lambda, 0.4159, 1e-3), "λ_2 = {}", last.lambda);
        assert!(within(last.beta, 1.2022, 1e-3), "β_2 = {}", last.beta);
        assert!(
            within(last.posterior_variance, 0.3460, 1e-3),
            "Σ_2 = {}",
            last.posterior_variance
        );

        let terminal_product = last.lambda * last.beta;
        assert!(
            within(terminal_product, 0.5, 1e-12),
            "terminal β_N·λ_N must equal 1/2, got {}",
            terminal_product
        );
    }
}