use crate::neural_path::NeuralPathValidator;
use ai3_lib::tensor::{Tensor, TensorData, TensorShape};
use ai3_lib::MiningTask;
use pot_o_core::TribeResult;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
/// A single mining challenge bound to one slot.
///
/// Instances produced by `ChallengeGenerator::generate` are derived deterministically from the
/// slot number and slot hash, apart from the `created_at`/`expires_at` timestamps.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Challenge {
    /// Hex-encoded SHA-256 identifier (as produced by the generator: slot LE bytes + decoded hash).
    pub id: String,
    /// Slot this challenge was generated for.
    pub slot: u64,
    /// The slot hash exactly as supplied to the generator (hex string).
    pub slot_hash: String,
    /// Name of the tensor operation the miner must perform.
    pub operation_type: String,
    /// Input tensor for the operation.
    pub input_tensor: Tensor,
    /// Effective difficulty for this challenge.
    pub difficulty: u64,
    /// MML threshold scaled down from the generator's base value by difficulty.
    pub mml_threshold: f64,
    /// Maximum allowed neural-path distance for this challenge.
    pub path_distance_max: u32,
    /// Upper bound on tensor side length the generator was configured with.
    pub max_tensor_dim: usize,
    /// UTC creation timestamp.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// UTC timestamp after which [`Challenge::is_expired`] reports `true`.
    pub expires_at: chrono::DateTime<chrono::Utc>,
}
impl Challenge {
    /// Returns `true` once the current UTC time is strictly later than `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < chrono::Utc::now()
    }

    /// Converts this challenge into a [`MiningTask`] owned by `requester`, using the
    /// challenge's input tensor as the task's only input.
    pub fn to_mining_task(&self, requester: &str) -> MiningTask {
        let operation = self.operation_type.clone();
        let inputs = vec![self.input_tensor.clone()];
        // NOTE(review): `50_000_000` and `300` are passed positionally to `MiningTask::new`;
        // their meaning is defined by that constructor's parameter list — confirm before editing.
        MiningTask::new(
            operation,
            inputs,
            self.difficulty,
            50_000_000,
            300,
            requester.to_owned(),
        )
    }
}
/// Configuration for deterministically deriving [`Challenge`]s from slot data.
#[derive(Debug, Clone)]
pub struct ChallengeGenerator {
    /// Starting difficulty; a per-epoch bonus (capped) is added on top for each slot.
    pub base_difficulty: u64,
    /// MML threshold before it is divided by `1 + log2(difficulty)`.
    pub base_mml_threshold: f64,
    /// Path-distance budget before it is reduced by the challenge difficulty.
    pub base_path_distance: u32,
    /// Upper bound on the side length of generated square input tensors.
    pub max_tensor_dim: usize,
    /// Lifetime of each generated challenge, in seconds.
    pub challenge_ttl_secs: i64,
}
impl Default for ChallengeGenerator {
    /// Default generator configuration.
    ///
    /// - base difficulty 2
    /// - base MML threshold 2.0
    /// - tensor cap `pot_o_core::ESP_MAX_TENSOR_DIM`
    /// - 120-second TTL
    /// - base path distance equal to the summed layer widths of the default
    ///   [`NeuralPathValidator`]
    fn default() -> Self {
        let total_width: usize = NeuralPathValidator::default().layer_widths.iter().sum();
        // Saturate instead of silently truncating if the summed widths exceed `u32::MAX`.
        let base_path_distance = u32::try_from(total_width).unwrap_or(u32::MAX);
        Self {
            base_difficulty: 2,
            base_mml_threshold: 2.0,
            base_path_distance,
            max_tensor_dim: pot_o_core::ESP_MAX_TENSOR_DIM,
            challenge_ttl_secs: 120,
        }
    }
}
/// Operation names a challenge may require.
///
/// `ChallengeGenerator::generate` selects an entry with the first slot-hash byte modulo this
/// list's length. Entry order is therefore part of challenge derivation: reordering or
/// inserting names changes which operation a given slot hash maps to.
const OPERATIONS: &[&str] = &[
    "matrix_multiply", "convolution", "relu", "sigmoid", "tanh", "dot_product", "normalize",
];
impl ChallengeGenerator {
    /// Creates a generator with the given base difficulty and tensor-size cap.
    ///
    /// Every other parameter comes from [`ChallengeGenerator::default`].
    pub fn new(difficulty: u64, max_tensor_dim: usize) -> Self {
        Self {
            base_difficulty: difficulty,
            max_tensor_dim,
            ..Default::default()
        }
    }

    /// Deterministically derives a [`Challenge`] from a slot number and its hex-encoded hash.
    ///
    /// The operation, input tensor, difficulty, thresholds and id are pure functions of
    /// `slot` and `slot_hash_hex`. Only `created_at`/`expires_at` depend on the wall clock.
    ///
    /// # Errors
    ///
    /// Returns `TribeError::InvalidOperation` in two cases:
    /// - `slot_hash_hex` is not valid hex (including odd length);
    /// - `max_tensor_dim` is zero.
    ///
    /// Errors from `Tensor::new` are propagated unchanged.
    pub fn generate(&self, slot: u64, slot_hash_hex: &str) -> TribeResult<Challenge> {
        let hash_bytes = hex::decode(slot_hash_hex).map_err(|e| {
            pot_o_core::TribeError::InvalidOperation(format!("Invalid slot hash hex: {e}"))
        })?;

        // First hash byte picks the operation; an empty hash falls back to index 0.
        let op_index = usize::from(hash_bytes.first().copied().unwrap_or(0)) % OPERATIONS.len();
        let operation_type = OPERATIONS[op_index].to_string();
        let input_tensor = self.derive_input_tensor(&hash_bytes)?;
        let difficulty = self.compute_difficulty(slot);

        // threshold = base / (1 + log2(difficulty)).
        // The log term is clamped at 0, so difficulty 0 (log2 = -inf) or 1 leaves the base as is.
        let mml_threshold = self.base_mml_threshold / (1.0 + (difficulty as f64).log2().max(0.0));

        // Reduce the path budget by the difficulty, but keep it at least 1 whenever the base
        // is at least 1. `saturating_sub(1)` keeps a zero base from underflowing, and
        // `try_from` saturates instead of truncating difficulties above `u32::MAX`.
        let difficulty_u32 = u32::try_from(difficulty).unwrap_or(u32::MAX);
        let path_distance_max = self
            .base_path_distance
            .saturating_sub(difficulty_u32.min(self.base_path_distance.saturating_sub(1)));

        let now = chrono::Utc::now();
        let challenge_id = {
            let mut h = Sha256::new();
            h.update(slot.to_le_bytes());
            h.update(&hash_bytes);
            hex::encode(h.finalize())
        };

        Ok(Challenge {
            id: challenge_id,
            slot,
            slot_hash: slot_hash_hex.to_string(),
            operation_type,
            input_tensor,
            difficulty,
            mml_threshold,
            path_distance_max,
            max_tensor_dim: self.max_tensor_dim,
            created_at: now,
            expires_at: now + chrono::Duration::seconds(self.challenge_ttl_secs),
        })
    }

    /// Builds a square `dim x dim` F32 tensor from the slot-hash bytes.
    ///
    /// `dim` comes from the second hash byte, defaulting to 4 when it is absent:
    /// - it lies in `2..=max_tensor_dim` when `max_tensor_dim >= 2`;
    /// - it is 1 when `max_tensor_dim == 1`.
    ///
    /// Each element is the matching hash byte scaled to `[0, 1]`. Positions beyond the end of
    /// the hash are filled with the fractional part of `index * 0.618034` (≈ 1/φ).
    ///
    /// # Errors
    ///
    /// Returns `TribeError::InvalidOperation` if `max_tensor_dim` is zero, which would
    /// otherwise cause a division by zero. Errors from `Tensor::new` are propagated.
    fn derive_input_tensor(&self, hash_bytes: &[u8]) -> TribeResult<Tensor> {
        if self.max_tensor_dim == 0 {
            return Err(pot_o_core::TribeError::InvalidOperation(
                "max_tensor_dim must be at least 1".to_string(),
            ));
        }
        let dim_byte = usize::from(hash_bytes.get(1).copied().unwrap_or(4));
        // dim_byte <= 255, so dim <= 257 and `dim * dim` cannot overflow.
        let dim = ((dim_byte % self.max_tensor_dim) + 2).min(self.max_tensor_dim);
        let total = dim * dim;
        let floats: Vec<f32> = (0..total)
            .map(|i| match hash_bytes.get(i) {
                Some(&b) => f32::from(b) / 255.0,
                None => (i as f32 * 0.618_034).fract(),
            })
            .collect();
        Tensor::new(TensorShape::new(vec![dim, dim]), TensorData::F32(floats))
    }

    /// Base difficulty plus one per 10,000-slot epoch, with the epoch bonus capped at +10.
    /// Saturates at `u64::MAX` instead of overflowing.
    fn compute_difficulty(&self, slot: u64) -> u64 {
        const SLOTS_PER_EPOCH: u64 = 10_000;
        const MAX_EPOCH_BONUS: u64 = 10;
        let epoch = slot / SLOTS_PER_EPOCH;
        self.base_difficulty.saturating_add(epoch.min(MAX_EPOCH_BONUS))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Locals are named `generator` rather than `gen`, which is a reserved keyword in Rust 2024.

    #[test]
    fn test_challenge_generation() {
        let generator = ChallengeGenerator::default();
        let hash = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
        let challenge = generator.generate(100, hash).unwrap();
        assert!(!challenge.id.is_empty());
        assert!(challenge.mml_threshold > 0.0);
        assert!(challenge.mml_threshold <= generator.base_mml_threshold);
        assert!(challenge.path_distance_max <= generator.base_path_distance);
        // First byte 0xab = 171; 171 % 7 == 3 -> "sigmoid".
        assert_eq!(challenge.operation_type, "sigmoid");
        assert!(challenge.expires_at > challenge.created_at);
        assert!(!challenge.is_expired());
    }

    #[test]
    fn test_deterministic_operation() {
        let generator = ChallengeGenerator::default();
        let hash = "ff00112233445566778899aabbccddeeff00112233445566778899aabbccddeeff";
        let c1 = generator.generate(42, hash).unwrap();
        let c2 = generator.generate(42, hash).unwrap();
        assert_eq!(c1.operation_type, c2.operation_type);
        assert_eq!(c1.id, c2.id);
    }

    #[test]
    fn test_invalid_hex_rejected() {
        let generator = ChallengeGenerator::default();
        assert!(generator.generate(1, "not-hex").is_err());
        // Odd-length input cannot be decoded into whole bytes.
        assert!(generator.generate(1, "abc").is_err());
    }

    #[test]
    fn test_difficulty_grows_per_epoch_and_caps() {
        let generator = ChallengeGenerator::default();
        let hash = "00";
        let base = generator.base_difficulty;
        assert_eq!(generator.generate(0, hash).unwrap().difficulty, base);
        assert_eq!(generator.generate(9_999, hash).unwrap().difficulty, base);
        assert_eq!(generator.generate(10_000, hash).unwrap().difficulty, base + 1);
        assert_eq!(generator.generate(10_000_000, hash).unwrap().difficulty, base + 10);
    }

    #[test]
    fn test_id_is_sha256_hex_and_slot_dependent() {
        let generator = ChallengeGenerator::default();
        let hash = "abcdef0123456789";
        let a = generator.generate(1, hash).unwrap();
        let b = generator.generate(2, hash).unwrap();
        assert_eq!(a.id.len(), 64);
        assert!(a.id.chars().all(|c| c.is_ascii_hexdigit()));
        assert_ne!(a.id, b.id);
    }
}