use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use crate::cc_tier::{CCTier, TierAttestation};
/// Fraction of each block reward nominally directed to the AI reward pool.
/// NOTE(review): not read in the code visible here; `calculate_block_reward_split`
/// gives the AI pool whatever remains after the validator share instead.
pub const AI_REWARD_POOL_SHARE: f64 = 0.10;
/// Fraction of each block reward paid to validators (used by
/// `calculate_block_reward_split`).
pub const VALIDATOR_REWARD_SHARE: f64 = 0.90;
/// Workload class a provider can serve.
///
/// The explicit discriminants (1..=5) are the numeric codes accepted by
/// [`ModelingLevel::from_u8`]; `Ord` follows the same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum ModelingLevel {
    InferenceLight = 1,
    InferenceStandard = 2,
    InferenceHeavy = 3,
    Training = 4,
    Specialized = 5,
}

impl ModelingLevel {
    /// Human-readable label; this is also what the `Display` impl prints.
    pub fn as_str(&self) -> &'static str {
        match *self {
            ModelingLevel::InferenceLight => "Inference-Light",
            ModelingLevel::InferenceStandard => "Inference-Standard",
            ModelingLevel::InferenceHeavy => "Inference-Heavy",
            ModelingLevel::Training => "Training",
            ModelingLevel::Specialized => "Specialized",
        }
    }

    /// Level-specific reward multiplier: 0.5 for light inference, rising in
    /// steps of 0.5 up to 2.5 for specialized work.
    pub fn base_reward_multiplier(&self) -> f64 {
        match *self {
            ModelingLevel::InferenceLight => 0.5,
            ModelingLevel::InferenceStandard => 1.0,
            ModelingLevel::InferenceHeavy => 1.5,
            ModelingLevel::Training => 2.0,
            ModelingLevel::Specialized => 2.5,
        }
    }

    /// Minimum VRAM, in gigabytes, associated with this level.
    ///
    /// The values are not monotonic in the level: heavy inference (80) asks for
    /// more than training (48), and specialized work asks for the least but one.
    pub fn min_vram_gb(&self) -> u64 {
        match *self {
            ModelingLevel::InferenceLight => 8,
            ModelingLevel::InferenceStandard => 24,
            ModelingLevel::InferenceHeavy => 80,
            ModelingLevel::Training => 48,
            ModelingLevel::Specialized => 16,
        }
    }

    /// Inverse of `level as u8`: returns the level whose discriminant equals `v`,
    /// or `None` for any value outside `1..=5`.
    pub fn from_u8(v: u8) -> Option<Self> {
        [
            ModelingLevel::InferenceLight,
            ModelingLevel::InferenceStandard,
            ModelingLevel::InferenceHeavy,
            ModelingLevel::Training,
            ModelingLevel::Specialized,
        ]
        .iter()
        .copied()
        .find(|&level| level as u8 == v)
    }
}

impl std::fmt::Display for ModelingLevel {
    /// Prints [`ModelingLevel::as_str`] verbatim (width/fill flags are ignored).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// A registered AI compute provider plus the bookkeeping used to weight its rewards.
#[derive(Debug, Clone)]
pub struct AIProvider {
    /// Unique identifier; `AIRewardPool::register_provider` also uses it as the map key.
    pub provider_id: String,
    /// Most recent tier attestation, if any; only honored while `is_valid()` holds.
    pub attestation: Option<TierAttestation>,
    /// Highest level the provider advertises; feeds the model multiplier in `reward_weight`.
    pub max_modeling_level: ModelingLevel,
    /// Level currently being served, if any (not read by the code in this section).
    pub current_modeling_level: Option<ModelingLevel>,
    /// Staked amount; compared against tier minimums and used for stake weighting.
    pub stake_lux: u64,
    /// Time of the last heartbeat; drives `is_online`.
    pub last_heartbeat: SystemTime,
    /// Consecutive epochs of participation; drives the uptime bonus.
    pub consecutive_epochs: u64,
    /// Tasks completed in the current epoch (not read by the reward math here).
    pub tasks_this_epoch: u64,
    /// Lifetime task count (not read by the reward math here).
    pub total_tasks_completed: u64,
    /// Reputation, starting at 0.5; not clamped by this type.
    pub reputation_score: f64,
}

impl AIProvider {
    /// Builds a provider with no attestation, a heartbeat stamped "now",
    /// zeroed counters and a neutral reputation of 0.5.
    pub fn new(provider_id: String, stake_lux: u64, max_level: ModelingLevel) -> Self {
        let now = SystemTime::now();
        Self {
            provider_id,
            stake_lux,
            max_modeling_level: max_level,
            current_modeling_level: None,
            attestation: None,
            last_heartbeat: now,
            reputation_score: 0.5,
            consecutive_epochs: 0,
            tasks_this_epoch: 0,
            total_tasks_completed: 0,
        }
    }

    /// True when the last heartbeat is strictly younger than `max_heartbeat_age`.
    ///
    /// NOTE(review): if `last_heartbeat` lies in the future relative to the system
    /// clock (e.g. after a clock step backwards), `elapsed()` errors and the
    /// provider is reported offline.
    pub fn is_online(&self, max_heartbeat_age: Duration) -> bool {
        self.last_heartbeat
            .elapsed()
            .map(|age| age < max_heartbeat_age)
            .unwrap_or(false)
    }

    /// Tier from a currently valid attestation, otherwise `CCTier::Tier4Standard`.
    pub fn effective_tier(&self) -> CCTier {
        match &self.attestation {
            Some(att) if att.is_valid() => att.tier,
            _ => CCTier::Tier4Standard,
        }
    }

    /// Relative weight used to split participation rewards:
    /// tier multiplier × level multiplier × stake weight × uptime bonus × reputation bonus.
    pub fn reward_weight(&self) -> f64 {
        // Stakes above 1000 scale as sqrt(stake / 1000), capped at 10x.
        let stake_weight = match self.stake_lux {
            s if s > 1000 => (s as f64 / 1000.0).sqrt().min(10.0),
            _ => 1.0,
        };
        // +0.1% per consecutive epoch, capped at +50% (reached after 500 epochs).
        let uptime_bonus = 1.0 + (self.consecutive_epochs as f64 / 1000.0).min(0.5);
        // Maps a reputation of 0.0..=1.0 onto 0.8..=1.2; out-of-range input is not clamped.
        let rep_bonus = 0.8 + (self.reputation_score * 0.4);
        self.effective_tier().reward_multiplier()
            * self.max_modeling_level.base_reward_multiplier()
            * stake_weight
            * uptime_bonus
            * rep_bonus
    }
}
/// Per-provider outcome of [`AIRewardPool::calculate_participation_rewards`].
#[derive(Debug, Clone)]
pub struct ParticipationRewardResult {
    /// Provider receiving the reward.
    pub provider_id: String,
    /// Reward in wei, truncated toward zero from a floating-point product.
    pub reward_lux_wei: u128,
    /// Raw weight from `AIProvider::reward_weight`.
    pub weight: f64,
    /// `weight` divided by the summed weight of all rewarded providers.
    pub weight_share: f64,
    /// Provider's effective tier at calculation time.
    pub tier: CCTier,
    /// Provider's maximum advertised modeling level.
    pub modeling_level: ModelingLevel,
}
/// Outcome of [`AIRewardPool::calculate_task_reward`] for one task.
#[derive(Debug, Clone)]
pub struct TaskRewardResult {
    /// Provider that performed the task.
    pub provider_id: String,
    /// Caller-supplied task identifier, echoed back.
    pub task_id: String,
    /// Reward in wei after tier and level multipliers.
    pub reward_lux_wei: u128,
    /// Level the task was billed at (supplied by the caller).
    pub modeling_level: ModelingLevel,
    /// Compute units the reward was scaled by.
    pub compute_units: u64,
}
/// Registry of AI providers together with the reward pool they share.
pub struct AIRewardPool {
    /// Registered providers, keyed by `provider_id`.
    pub providers: HashMap<String, AIProvider>,
    /// Epoch counter reported in epoch summaries.
    /// NOTE(review): not incremented by the code visible here.
    pub epoch_number: u64,
    /// Configured epoch length. NOTE(review): not read by the code visible here.
    pub epoch_duration: Duration,
    /// AI-pool balance in wei from which participation rewards are carved.
    pub total_pool_lux_wei: u128,
    /// Fraction of the pool paid as participation rewards (0.30 by default).
    pub participation_share: f64,
    /// Fraction nominally reserved for task rewards (0.70 by default).
    /// NOTE(review): not read here; the epoch summary derives the task pool as
    /// the remainder after the participation share.
    pub task_share: f64,
}

impl AIRewardPool {
    /// Creates an empty pool with a 30% participation / 70% task split.
    pub fn new(epoch_duration: Duration) -> Self {
        Self {
            providers: HashMap::new(),
            epoch_number: 0,
            epoch_duration,
            total_pool_lux_wei: 0,
            participation_share: 0.30,
            task_share: 0.70,
        }
    }

    /// Adds `provider` to the registry.
    ///
    /// A provider already registered under the same ID is replaced wholesale,
    /// including its counters and reputation.
    ///
    /// # Errors
    /// * `"provider ID required"` — `provider_id` is empty.
    /// * `"insufficient stake"` — stake is below `CCTier::Tier4Standard.min_stake_lux()`.
    pub fn register_provider(&mut self, provider: AIProvider) -> Result<(), &'static str> {
        if provider.provider_id.is_empty() {
            return Err("provider ID required");
        }
        // Registration only enforces the minimum of Tier 4, the tier an
        // unattested provider falls back to; tier-specific minimums are checked
        // later by `random_mining_eligibility`.
        let min_stake = CCTier::Tier4Standard.min_stake_lux();
        if provider.stake_lux < min_stake {
            return Err("insufficient stake");
        }
        self.providers.insert(provider.provider_id.clone(), provider);
        Ok(())
    }

    /// Splits the participation share of `total_pool_lux_wei` among providers
    /// that are online and hold a valid attestation, in proportion to
    /// [`AIProvider::reward_weight`].
    ///
    /// Results are ordered by `provider_id`. Returns an empty vector when no
    /// provider qualifies or the weights sum to zero. The total paid out never
    /// exceeds the participation pool; any rounding remainder is left undistributed.
    pub fn calculate_participation_rewards(
        &self,
        max_heartbeat_age: Duration,
    ) -> Vec<ParticipationRewardResult> {
        let participation_pool =
            (self.total_pool_lux_wei as f64 * self.participation_share) as u128;

        // Compute each weight exactly once and keep it paired with its provider.
        let mut eligible: Vec<(&AIProvider, f64)> = self
            .providers
            .values()
            .filter(|p| p.is_online(max_heartbeat_age))
            .filter(|p| p.attestation.as_ref().map_or(false, |a| a.is_valid()))
            .map(|p| (p, p.reward_weight()))
            .collect();

        // `HashMap` iteration order differs between processes and float addition
        // is not associative, so summing in map order can make `total_weight` (and
        // therefore each reward) differ between nodes. A fixed order makes both the
        // sum and the output order deterministic.
        eligible.sort_unstable_by(|(a, _), (b, _)| a.provider_id.cmp(&b.provider_id));

        let total_weight: f64 = eligible.iter().map(|&(_, w)| w).sum();
        if eligible.is_empty() || total_weight == 0.0 {
            return Vec::new();
        }

        // `participation_pool as f64` can round *up* once the pool exceeds 2^53 wei,
        // so a provider holding the whole share could otherwise be paid more than
        // the pool. Capping against the remaining budget keeps payouts within it.
        let mut remaining = participation_pool;
        let mut results = Vec::with_capacity(eligible.len());
        for (provider, weight) in eligible {
            let share = weight / total_weight;
            let reward = ((participation_pool as f64 * share) as u128).min(remaining);
            remaining -= reward;
            results.push(ParticipationRewardResult {
                provider_id: provider.provider_id.clone(),
                reward_lux_wei: reward,
                weight,
                weight_share: share,
                tier: provider.effective_tier(),
                modeling_level: provider.max_modeling_level,
            });
        }
        results
    }

    /// Reward for one completed task: `1e12 wei × compute_units`, scaled first by
    /// the provider's effective-tier multiplier and then by `modeling_level`'s
    /// multiplier (each step truncates toward zero).
    ///
    /// NOTE(review): `modeling_level` is taken from the caller and is not checked
    /// against the provider's `max_modeling_level`.
    pub fn calculate_task_reward(
        &self,
        provider: &AIProvider,
        task_id: String,
        modeling_level: ModelingLevel,
        compute_units: u64,
    ) -> TaskRewardResult {
        let base_rate_wei: u128 = 1_000_000_000_000;
        // Cannot overflow: 1e12 * u64::MAX is far below u128::MAX.
        let mut reward = base_rate_wei * compute_units as u128;
        let tier_mult = provider.effective_tier().reward_multiplier();
        reward = (reward as f64 * tier_mult) as u128;
        let level_mult = modeling_level.base_reward_multiplier();
        reward = (reward as f64 * level_mult) as u128;
        TaskRewardResult {
            provider_id: provider.provider_id.clone(),
            task_id,
            reward_lux_wei: reward,
            modeling_level,
            compute_units,
        }
    }
}
/// Splits a block reward (in wei) between validators and the AI reward pool.
///
/// The validator part is `floor(total * VALIDATOR_REWARD_SHARE)`, computed in
/// exact integer arithmetic. The AI pool receives the remainder, so the two parts
/// always sum to `total_block_reward_wei`.
///
/// Returns `(validator_reward_wei, ai_pool_reward_wei)`.
pub fn calculate_block_reward_split(total_block_reward_wei: u128) -> (u128, u128) {
    // Fixed-point scale for the share. Routing the u128 amount through f64 would
    // keep only 53 significant bits and mis-split any reward above ~9.007e15 wei.
    const SCALE: u128 = 1_000_000_000;
    // Clamping keeps the validator part <= total, so the subtraction below cannot
    // underflow even if the constant is ever misconfigured.
    let share_scaled = (VALIDATOR_REWARD_SHARE.clamp(0.0, 1.0) * SCALE as f64).round() as u128;
    // floor(total * share_scaled / SCALE), split so no intermediate overflows u128.
    let whole = total_block_reward_wei / SCALE;
    let rest = total_block_reward_wei % SCALE;
    let validator_reward = whole * share_scaled + rest * share_scaled / SCALE;
    let ai_pool_reward = total_block_reward_wei - validator_reward;
    (validator_reward, ai_pool_reward)
}
/// Checks whether `provider` may take part in random mining.
///
/// Checks run in this order and the first failure is returned:
/// 1. `"provider offline"` — heartbeat not younger than `max_heartbeat_age`.
/// 2. `"no attestation"` — no attestation on record.
/// 3. `"attestation expired"` — an attestation exists but `is_valid()` is false.
/// 4. `"insufficient stake"` — stake below the effective tier's minimum stake.
pub fn random_mining_eligibility(
    provider: &AIProvider,
    max_heartbeat_age: Duration,
) -> Result<(), &'static str> {
    if !provider.is_online(max_heartbeat_age) {
        return Err("provider offline");
    }
    match provider.attestation.as_ref() {
        None => Err("no attestation"),
        Some(attestation) if !attestation.is_valid() => Err("attestation expired"),
        Some(_) if provider.stake_lux < provider.effective_tier().min_stake_lux() => {
            Err("insufficient stake")
        }
        Some(_) => Ok(()),
    }
}
/// Per-epoch accounting produced by [`AIRewardPool::calculate_epoch_rewards`].
#[derive(Debug)]
pub struct EpochRewardSummary {
    /// The pool's `epoch_number` at the time the summary was computed.
    pub epoch_number: u64,
    /// Total block rewards for the epoch, in wei.
    pub total_block_rewards_wei: u128,
    /// Validator portion of the block rewards.
    pub validator_rewards_wei: u128,
    /// AI-pool portion; together with `validator_rewards_wei` it equals the total.
    pub ai_pool_rewards_wei: u128,
    /// Participation slice of the AI pool.
    pub participation_rewards_wei: u128,
    /// Task slice of the AI pool (the remainder after the participation slice).
    pub task_rewards_wei: u128,
    /// Providers whose heartbeat was fresh when the summary was computed.
    pub online_providers: u64,
    /// All registered providers, online or not.
    pub total_providers: u64,
    /// Online providers counted by effective tier.
    pub tier_distribution: HashMap<CCTier, u64>,
}
impl AIRewardPool {
    /// Splits `total_block_rewards_wei` between validators and the AI pool and
    /// summarizes how the AI pool divides into participation and task rewards.
    ///
    /// Side effect: `total_pool_lux_wei` is overwritten with this epoch's AI-pool
    /// amount, which `calculate_participation_rewards` later draws from.
    /// `epoch_number` is reported but not advanced.
    ///
    /// `online_providers` counts every provider with a fresh heartbeat, whether or
    /// not it holds a valid attestation. Providers without a valid attestation are
    /// tallied under `CCTier::Tier4Standard`, the `effective_tier` fallback.
    pub fn calculate_epoch_rewards(
        &mut self,
        total_block_rewards_wei: u128,
        max_heartbeat_age: Duration,
    ) -> EpochRewardSummary {
        let (validator_rewards_wei, ai_pool_rewards_wei) =
            calculate_block_reward_split(total_block_rewards_wei);
        self.total_pool_lux_wei = ai_pool_rewards_wei;

        // Same formula `calculate_participation_rewards` uses for its pool. The task
        // pool takes the remainder (not `task_share`), so the two always add up.
        let participation_rewards_wei =
            (ai_pool_rewards_wei as f64 * self.participation_share) as u128;
        let task_rewards_wei = ai_pool_rewards_wei - participation_rewards_wei;

        // One bucket per effective tier, counting only online providers.
        let tier_distribution: HashMap<CCTier, u64> = self
            .providers
            .values()
            .filter(|provider| provider.is_online(max_heartbeat_age))
            .fold(HashMap::new(), |mut buckets, provider| {
                *buckets.entry(provider.effective_tier()).or_insert(0) += 1;
                buckets
            });
        // Every online provider landed in exactly one bucket.
        let online_providers: u64 = tier_distribution.values().sum();

        EpochRewardSummary {
            epoch_number: self.epoch_number,
            total_block_rewards_wei,
            validator_rewards_wei,
            ai_pool_rewards_wei,
            participation_rewards_wei,
            task_rewards_wei,
            online_providers,
            total_providers: self.providers.len() as u64,
            tier_distribution,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Level multipliers must strictly increase with the level.
    #[test]
    fn test_modeling_level_multipliers() {
        assert!(
            ModelingLevel::InferenceLight.base_reward_multiplier()
                < ModelingLevel::InferenceStandard.base_reward_multiplier()
        );
        assert!(
            ModelingLevel::InferenceStandard.base_reward_multiplier()
                < ModelingLevel::InferenceHeavy.base_reward_multiplier()
        );
        assert!(
            ModelingLevel::InferenceHeavy.base_reward_multiplier()
                < ModelingLevel::Training.base_reward_multiplier()
        );
        assert!(
            ModelingLevel::Training.base_reward_multiplier()
                < ModelingLevel::Specialized.base_reward_multiplier()
        );
    }

    /// `from_u8` inverts `as u8` for every level and rejects out-of-range codes;
    /// `Display` prints the same label as `as_str`.
    #[test]
    fn test_modeling_level_round_trip() {
        for v in 1u8..=5 {
            let level = ModelingLevel::from_u8(v).expect("1..=5 are valid level codes");
            assert_eq!(level as u8, v);
            assert_eq!(level.to_string(), level.as_str());
        }
        assert_eq!(ModelingLevel::from_u8(0), None);
        assert_eq!(ModelingLevel::from_u8(6), None);
    }

    #[test]
    fn test_block_reward_split() {
        let total = 100_000_000_000_000_000_000u128;
        let (validator, ai_pool) = calculate_block_reward_split(total);
        let expected_validator = 90_000_000_000_000_000_000u128;
        assert_eq!(validator, expected_validator);
        let expected_ai = 10_000_000_000_000_000_000u128;
        assert_eq!(ai_pool, expected_ai);
        assert_eq!(validator + ai_pool, total);
    }

    /// The split never creates or loses wei, including at the extremes of u128.
    #[test]
    fn test_block_reward_split_conserves_total() {
        for &total in &[0u128, 1, 7, 999, 123_456_789_012_345_678_901, u128::MAX] {
            let (validator, ai_pool) = calculate_block_reward_split(total);
            assert!(validator <= total);
            assert_eq!(validator + ai_pool, total);
        }
    }

    #[test]
    fn test_provider_reward_weight() {
        let mut provider = AIProvider::new(
            "test".to_string(),
            100_000,
            ModelingLevel::InferenceHeavy,
        );
        // Attestation valid for 5 hours.
        provider.attestation = Some(TierAttestation::new_valid(
            CCTier::Tier1GpuNativeCC,
            Duration::from_secs(3600 * 5),
        ));
        provider.reputation_score = 0.9;
        provider.consecutive_epochs = 500;
        let weight = provider.reward_weight();
        assert!(weight > 10.0, "Expected weight > 10, got {}", weight);
        assert!(weight < 50.0, "Expected weight < 50, got {}", weight);
    }

    #[test]
    fn test_pool_registration() {
        let mut pool = AIRewardPool::new(Duration::from_secs(3600));
        let provider = AIProvider::new(
            "provider-1".to_string(),
            50_000,
            ModelingLevel::InferenceStandard,
        );
        assert!(pool.register_provider(provider).is_ok());
        assert!(pool.providers.contains_key("provider-1"));
        let low_stake = AIProvider::new(
            "low-stake".to_string(),
            100,
            ModelingLevel::InferenceLight,
        );
        assert!(pool.register_provider(low_stake).is_err());
    }

    #[test]
    fn test_pool_registration_rejects_empty_id() {
        let mut pool = AIRewardPool::new(Duration::from_secs(3600));
        let nameless = AIProvider::new(String::new(), 50_000, ModelingLevel::Training);
        assert_eq!(pool.register_provider(nameless), Err("provider ID required"));
        assert!(pool.providers.is_empty());
    }

    /// With no eligible providers there is nothing to distribute.
    #[test]
    fn test_participation_rewards_empty_pool() {
        let pool = AIRewardPool::new(Duration::from_secs(3600));
        assert!(pool
            .calculate_participation_rewards(Duration::from_secs(300))
            .is_empty());
    }

    #[test]
    fn test_task_reward_zero_compute_units() {
        let pool = AIRewardPool::new(Duration::from_secs(3600));
        let provider = AIProvider::new("p".to_string(), 50_000, ModelingLevel::Training);
        let result =
            pool.calculate_task_reward(&provider, "task-0".to_string(), ModelingLevel::Training, 0);
        assert_eq!(result.reward_lux_wei, 0);
        assert_eq!(result.task_id, "task-0");
        assert_eq!(result.compute_units, 0);
    }

    /// Every layer of the epoch summary must account for the full amount above it.
    #[test]
    fn test_epoch_rewards_pools_add_up() {
        let mut pool = AIRewardPool::new(Duration::from_secs(3600));
        pool.register_provider(AIProvider::new(
            "p1".to_string(),
            50_000,
            ModelingLevel::Training,
        ))
        .expect("50k stake is accepted (see test_pool_registration)");
        let summary =
            pool.calculate_epoch_rewards(1_000_000_000_000_000_000, Duration::from_secs(300));
        assert_eq!(
            summary.validator_rewards_wei + summary.ai_pool_rewards_wei,
            summary.total_block_rewards_wei
        );
        assert_eq!(
            summary.participation_rewards_wei + summary.task_rewards_wei,
            summary.ai_pool_rewards_wei
        );
        assert_eq!(pool.total_pool_lux_wei, summary.ai_pool_rewards_wei);
        assert_eq!(summary.online_providers, 1);
        assert_eq!(summary.total_providers, 1);
    }

    #[test]
    fn test_random_mining_eligibility() {
        let mut provider = AIProvider::new(
            "test".to_string(),
            50_000,
            ModelingLevel::InferenceStandard,
        );
        // No attestation yet.
        assert!(random_mining_eligibility(&provider, Duration::from_secs(300)).is_err());
        provider.attestation = Some(TierAttestation::new_valid(
            CCTier::Tier2ConfidentialVM,
            Duration::from_secs(3600 * 23),
        ));
        assert!(random_mining_eligibility(&provider, Duration::from_secs(300)).is_ok());
        // Stale heartbeat: 600s old against a 300s limit.
        provider.last_heartbeat = SystemTime::now() - Duration::from_secs(600);
        assert!(random_mining_eligibility(&provider, Duration::from_secs(300)).is_err());
    }
}