use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use tracing::{debug, info, warn};
/// Serializable parameters of a single detection model.
///
/// The field order is left untouched because it determines the key order of the
/// serialized representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelWeights {
    /// Identifier of this weight set; `ModelWeights::new` fills it with a random UUID v4 string.
    pub model_id: String,
    /// Version number of the weight set; `ModelWeights::new` always starts at `1`.
    pub version: u32,
    /// Per-key weight values (keys are presumably feature names — confirm against the scorer).
    pub weights: HashMap<String, f64>,
    /// Bias term stored alongside the weights.
    pub bias: f64,
    /// Training-example count supplied by whoever built the weights.
    pub training_count: usize,
    /// Creation time; `ModelWeights::new` sets it from `chrono::Utc::now().timestamp()`.
    pub timestamp: i64,
    /// Hex digest of the feature list the weights were built against, compared by
    /// `ModelWeights::is_compatible`.
    pub schema_hash: String,
}
/// Construction and schema-compatibility helpers for `ModelWeights`.
impl ModelWeights {
    /// Builds a version-1 weight set stamped with a fresh UUID, the current UTC
    /// timestamp and the schema hash of this build.
    pub fn new(weights: HashMap<String, f64>, bias: f64, training_count: usize) -> Self {
        let model_id = uuid::Uuid::new_v4().to_string();
        let timestamp = chrono::Utc::now().timestamp();
        let schema_hash = Self::compute_schema_hash();

        Self {
            model_id,
            version: 1,
            weights,
            bias,
            training_count,
            timestamp,
            schema_hash,
        }
    }

    /// Returns the lowercase-hex digest of the ordered feature list this build understands.
    ///
    /// Both the names and their order feed the digest, so any edit to the list
    /// changes the result.
    ///
    /// NOTE(review): the digest comes from `DefaultHasher`, whose algorithm the standard
    /// library does not promise to keep stable across releases. It is left as-is because
    /// switching hashers would change the value stored in existing models.
    fn compute_schema_hash() -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        const SCHEMA_FEATURES: [&str; 28] = [
            "status_code",
            "response_length",
            "response_time",
            "payload_reflected",
            "has_error_patterns",
            "differs_from_baseline",
            "severity",
            "confidence",
            "sql_injection",
            "xss",
            "csrf",
            "ssrf",
            "xxe",
            "command_injection",
            "path_traversal",
            "idor",
            "auth_bypass",
            "jwt",
            "nosql_injection",
            "cors",
            "open_redirect",
            "file_upload",
            "deserialization",
            "ssti",
            "prototype_pollution",
            "race_condition",
            "bola",
            "info_disclosure",
        ];

        let mut state = DefaultHasher::new();
        SCHEMA_FEATURES.hash(&mut state);
        let digest = state.finish();
        format!("{:x}", digest)
    }

    /// Reports whether the stored `schema_hash` equals the digest computed by this build.
    pub fn is_compatible(&self) -> bool {
        Self::compute_schema_hash() == self.schema_hash
    }
}
/// A model as delivered by the server, together with aggregation metadata.
///
/// This is the shape that is deserialized from the `/model/latest` response and
/// written to / read from the on-disk cache file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatedModel {
    /// Version number of the aggregated model.
    pub global_version: u32,
    /// The model parameters themselves.
    pub weights: ModelWeights,
    /// Contributor count reported with the model.
    pub contributor_count: usize,
    /// Training-example count reported with the model.
    pub total_training_examples: usize,
    /// Signature string shipped with the model. The code visible here stores it but
    /// never inspects or verifies it.
    pub server_signature: String,
}
/// One entry of the list deserialized from the server's `/model/categories` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCategory {
    /// Category name.
    pub name: String,
    /// Feature count reported for the category.
    pub feature_count: usize,
    /// Optional free-text description; may be absent.
    pub description: Option<String>,
}
/// Client for the model distribution server.
pub struct FederatedClient {
    /// Base URL that endpoint paths such as `/model/latest` are appended to.
    server_url: String,
    /// Model currently held in memory, if any; `None` until a fetch stores one.
    global_model: Option<AggregatedModel>,
    /// License key set through `set_license_key` or `load_license_key`.
    /// NOTE(review): the code visible here stores the key but does not attach it to
    /// any request — confirm whether that is intentional.
    license_key: Option<String>,
}
impl FederatedClient {
pub fn new() -> Result<Self> {
let data_dir = Self::get_data_dir()?;
fs::create_dir_all(&data_dir)?;
Ok(Self {
server_url: "https://lonkero.bountyy.fi/api/federated/v1".to_string(),
global_model: None,
license_key: None,
})
}
pub fn set_license_key(&mut self, key: String) {
self.license_key = Some(key);
}
pub fn load_license_key(&mut self) {
if let Ok(key) = std::env::var("LONKERO_LICENSE_KEY") {
if !key.is_empty() {
self.license_key = Some(key);
return;
}
}
if let Ok(entry) = keyring::Entry::new("lonkero", "license_key") {
if let Ok(key) = entry.get_password() {
if !key.is_empty() {
debug!("ML: License key loaded from OS keychain");
self.license_key = Some(key);
return;
}
}
}
if let Some(config_dir) = dirs::config_dir() {
let license_file = config_dir.join("lonkero").join("license.key");
if license_file.exists() {
if let Ok(content) = std::fs::read_to_string(&license_file) {
let key = content.trim().to_string();
if !key.is_empty() {
debug!("ML: License key loaded from legacy config file");
self.license_key = Some(key);
}
}
}
}
}
fn get_data_dir() -> Result<PathBuf> {
let home = dirs::home_dir().context("Could not determine home directory")?;
Ok(home.join(".lonkero").join("federated"))
}
pub async fn fetch_global_model(&mut self) -> Result<Option<AggregatedModel>> {
info!("Fetching detection model from server...");
let client = reqwest::Client::new();
let request = client
.get(format!("{}/model/latest", self.server_url))
.timeout(std::time::Duration::from_secs(30));
let response = request.send().await;
match response {
Ok(resp) if resp.status().is_success() => {
let model: AggregatedModel = resp.json().await?;
info!(
"Fetched detection model v{} ({} contributors, {} examples)",
model.global_version, model.contributor_count, model.total_training_examples
);
if !model.weights.is_compatible() {
warn!("Detection model has incompatible schema, skipping");
return Ok(None);
}
self.global_model = Some(model.clone());
self.save_global_model(&model)?;
Ok(Some(model))
}
Ok(resp) if resp.status().as_u16() == 401 => {
warn!("Model download returned 401 - server configuration issue");
warn!("Please contact support at info@bountyy.fi");
Ok(self.load_cached_global_model()?)
}
Ok(resp) => {
debug!("No detection model available: {}", resp.status());
Ok(self.load_cached_global_model()?)
}
Err(e) => {
debug!("Could not reach server: {}", e);
Ok(self.load_cached_global_model()?)
}
}
}
pub async fn fetch_categories(&self) -> Result<Vec<ModelCategory>> {
let client = reqwest::Client::new();
let response = client
.get(format!("{}/model/categories", self.server_url))
.timeout(std::time::Duration::from_secs(30))
.send()
.await;
match response {
Ok(resp) if resp.status().is_success() => {
let categories: Vec<ModelCategory> = resp.json().await?;
info!("Fetched {} detection categories", categories.len());
Ok(categories)
}
Ok(resp) => {
debug!("Could not fetch categories: {}", resp.status());
Ok(Vec::new())
}
Err(e) => {
debug!("Could not reach server for categories: {}", e);
Ok(Vec::new())
}
}
}
pub fn get_model(&self) -> Option<&AggregatedModel> {
self.global_model.as_ref()
}
fn save_global_model(&self, model: &AggregatedModel) -> Result<()> {
let path = Self::get_data_dir()?.join("global_model.json");
let json = serde_json::to_string_pretty(model)?;
fs::write(path, json)?;
Ok(())
}
fn load_cached_global_model(&self) -> Result<Option<AggregatedModel>> {
let path = Self::get_data_dir()?.join("global_model.json");
if path.exists() {
let content = fs::read_to_string(path)?;
let model: AggregatedModel = serde_json::from_str(&content)?;
debug!("Loaded cached detection model v{}", model.global_version);
Ok(Some(model))
} else {
Ok(None)
}
}
pub fn load_cached_model() -> Result<AggregatedModel> {
let path = Self::get_data_dir()?.join("global_model.json");
let content = fs::read_to_string(&path)
.context("No cached model found at ~/.lonkero/federated/global_model.json")?;
let model: AggregatedModel = serde_json::from_str(&content)
.context("Failed to parse cached model")?;
info!("Loaded cached detection model v{}", model.global_version);
Ok(model)
}
pub async fn fetch_and_cache_model(&mut self) -> Result<AggregatedModel> {
match self.fetch_global_model().await? {
Some(model) => Ok(model),
None => Err(anyhow::anyhow!("No model available from server")),
}
}
pub fn get_stats(&self) -> FederatedStats {
FederatedStats {
has_global_model: self.global_model.is_some(),
global_version: self.global_model.as_ref().map(|m| m.global_version),
global_contributors: self.global_model.as_ref().map(|m| m.contributor_count),
}
}
}
/// `Default` delegates to `FederatedClient::new`.
///
/// # Panics
///
/// Panics with "Failed to create model distribution client" when `new` returns an
/// error. Callers that need to handle that failure should call `new` directly.
impl Default for FederatedClient {
    fn default() -> Self {
        let created = Self::new();
        created.expect("Failed to create model distribution client")
    }
}
/// Snapshot of the client's in-memory model state.
///
/// The default value describes a client that holds no model. `Clone`, `PartialEq`
/// and `Eq` are derived so snapshots can be copied and compared.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FederatedStats {
    /// `true` when a model is held in memory.
    pub has_global_model: bool,
    /// Version of the held model, or `None` when there is none.
    pub global_version: Option<u32>,
    /// Contributor count of the held model, or `None` when there is none.
    pub global_contributors: Option<usize>,
}
// Unit tests for weight construction and for the client's initial state.
// The client tests call `FederatedClient::new`, which creates the data directory
// under the home directory.
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built weight set carries the schema hash of this build.
    #[test]
    fn test_schema_compatibility() {
        let fresh = ModelWeights::new(HashMap::new(), 0.0, 50);
        assert!(fresh.is_compatible());
    }

    /// Constructor arguments end up in the matching fields.
    #[test]
    fn test_model_weights_creation() {
        let mut feature_weights = HashMap::new();
        feature_weights.insert("feature1".to_string(), 0.8);

        let built = ModelWeights::new(feature_weights, -0.42, 100);

        assert_eq!(built.bias, -0.42);
        assert_eq!(built.training_count, 100);
        assert!(built.weights.contains_key("feature1"));
    }

    /// A new client starts without a model and without a license key.
    #[test]
    fn test_federated_client_creation() {
        let created = FederatedClient::new();
        assert!(created.is_ok());

        let fresh_client = created.unwrap();
        assert!(fresh_client.global_model.is_none());
        assert!(fresh_client.license_key.is_none());
    }

    /// `set_license_key` stores exactly the key it was given.
    #[test]
    fn test_set_license_key() {
        let mut client = FederatedClient::new().unwrap();
        client.set_license_key("test-key-123".to_string());

        let stored = client.license_key.as_deref();
        assert_eq!(stored, Some("test-key-123"));
    }

    /// Stats of a client without a model report nothing.
    #[test]
    fn test_stats_no_model() {
        let FederatedStats {
            has_global_model,
            global_version,
            global_contributors,
        } = FederatedClient::new().unwrap().get_stats();

        assert!(!has_global_model);
        assert!(global_version.is_none());
        assert!(global_contributors.is_none());
    }
}