use crate::types::UUID;
use crate::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// One piece of feedback about a response, as passed to
/// `LearningProvider::collect_feedback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearningFeedback {
/// Identifier of the response the feedback is about.
pub response_id: UUID,
/// Numeric rating. The valid range is not defined in this file.
pub score: f32,
/// Optional free-form text accompanying the score.
pub text: Option<String>,
/// Who or what produced the feedback.
pub source: FeedbackSource,
/// Time the feedback was recorded.
/// NOTE(review): the unit tests fill this with Unix seconds — confirm that is the intended unit.
pub timestamp: i64,
}
/// Origin of a `LearningFeedback` entry (see `LearningFeedback::source`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FeedbackSource {
User,
Evaluator,
System,
Expert,
}
/// Outcome of a `LearningProvider::train` call.
///
/// The default `train` implementation returns this with `success == false`,
/// zeroed counters, empty `metrics` and an explanatory `error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
/// Whether the training run succeeded.
pub success: bool,
/// Loss reported by the run, if any.
pub loss: Option<f32>,
/// Number of examples used.
pub examples_count: usize,
/// Duration of the run, in seconds (per the field name).
pub duration_secs: f64,
/// Error message when the run did not succeed.
pub error: Option<String>,
/// Additional named metrics; the set of keys is provider-defined.
pub metrics: HashMap<String, f64>,
}
/// Extension point for collecting response feedback and, optionally, training on it.
///
/// Implementors must provide [`name`](Self::name),
/// [`collect_feedback`](Self::collect_feedback) and
/// [`get_feedback`](Self::get_feedback). Every other method has a default that
/// reports the capability as unavailable.
#[async_trait]
pub trait LearningProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Stores one feedback entry.
    async fn collect_feedback(&self, feedback: LearningFeedback) -> Result<()>;

    /// Returns at most `limit` stored feedback entries.
    async fn get_feedback(&self, limit: usize) -> Result<Vec<LearningFeedback>>;

    /// Whether [`train`](Self::train) is actually implemented. Defaults to `false`.
    fn supports_training(&self) -> bool {
        false
    }

    /// Runs a training pass.
    ///
    /// The default does no work and returns an unsuccessful [`TrainingResult`]
    /// whose `error` explains that training is not available.
    async fn train(&self) -> Result<TrainingResult> {
        let unavailable = TrainingResult {
            success: false,
            loss: None,
            examples_count: 0,
            duration_secs: 0.0,
            error: Some(String::from("Training requires enterprise license")),
            metrics: HashMap::new(),
        };
        Ok(unavailable)
    }

    /// Requests continual learning. The default returns `Ok(false)`.
    async fn enable_continual_learning(&self) -> Result<bool> {
        Ok(false)
    }

    /// Loads the adapter named `_adapter_name`. The default ignores the
    /// argument and returns `Ok(false)`.
    async fn load_adapter(&self, _adapter_name: &str) -> Result<bool> {
        Ok(false)
    }

    /// Lists available adapter names. The default returns an empty list.
    async fn list_adapters(&self) -> Result<Vec<String>> {
        Ok(Vec::new())
    }
}
/// One PII match reported by `ComplianceProvider::scan_pii`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PiiFinding {
/// Category of the detected PII.
pub pii_type: PiiType,
/// Start offset of the match in the scanned text. `BasicComplianceProvider`
/// fills this from the regex match start.
pub start: usize,
/// End offset of the match. `BasicComplianceProvider` fills this from the
/// regex match end.
pub end: usize,
/// The matched substring itself, unredacted.
pub matched_text: String,
/// Severity assigned to this finding.
pub severity: Severity,
/// Detector confidence. `BasicComplianceProvider` always reports `0.9`;
/// the scale is otherwise not defined in this file.
pub confidence: f32,
}
/// Category of personally identifiable (or otherwise sensitive) data.
///
/// `BasicComplianceProvider` only has detectors for `Ssn`, `Email`, `Phone`,
/// `CreditCard` and `ApiKey`; the remaining variants are available to other
/// providers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PiiType {
Ssn,
Email,
Phone,
CreditCard,
IpAddress,
ApiKey,
Name,
Address,
DateOfBirth,
MedicalId,
/// Provider-defined category, identified by the contained string.
Custom(String),
}
/// Severity level attached to PII and compliance findings.
///
/// `Ord` is derived, so variants compare in declaration order:
/// `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Severity {
Low,
Medium,
High,
Critical,
}
/// Compliance framework that a `ComplianceProvider` can be queried about via
/// `supports_framework` and `check_compliance`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ComplianceFramework {
Hipaa,
Gdpr,
Fda,
PciDss,
Soc2,
}
/// Result of `ComplianceProvider::check_compliance` for a single framework.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceCheckResult {
/// Framework that was checked.
pub framework: ComplianceFramework,
/// Overall verdict for that framework.
pub compliant: bool,
/// Individual findings produced by the check.
pub findings: Vec<ComplianceFinding>,
/// Free-form recommendations.
pub recommendations: Vec<String>,
/// Time of the check. NOTE(review): unit is not stated in this file — confirm (Unix seconds?).
pub checked_at: i64,
}
/// A single finding inside a `ComplianceCheckResult` or `ComplianceAuditReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceFinding {
/// Identifier of the finding; the code scheme is provider-defined.
pub code: String,
/// Human-readable description.
pub description: String,
/// Severity of the finding.
pub severity: Severity,
/// Suggested remediation, when one is available.
pub remediation: Option<String>,
}
/// One audit-log record, written with `ComplianceProvider::audit_log` and read
/// back with `ComplianceProvider::get_audit_logs`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEntry {
/// Identifier of this entry.
pub id: UUID,
/// Time of the audited event. NOTE(review): unit is not stated in this file — confirm.
pub timestamp: i64,
/// Who performed the action.
pub actor: String,
/// What was done.
pub action: String,
/// What the action was applied to.
pub resource: String,
/// How the action ended.
pub outcome: AuditOutcome,
/// Arbitrary extra key/value data attached to the entry.
pub context: HashMap<String, serde_json::Value>,
}
/// Outcome recorded on an `AuditEntry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AuditOutcome {
Success,
Failure,
Denied,
}
/// Extension point for PII detection, redaction, compliance checks and audit logging.
///
/// Implementors must provide [`name`](Self::name), [`scan_pii`](Self::scan_pii)
/// and [`redact`](Self::redact). The framework and audit methods default to
/// "not supported" answers.
#[async_trait]
pub trait ComplianceProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Scans `text` and returns every PII match found.
    async fn scan_pii(&self, text: &str) -> Result<Vec<PiiFinding>>;

    /// Returns a copy of `text` with detected PII replaced.
    fn redact(&self, text: &str) -> String;

    /// Whether `_framework` can be checked by this provider. Defaults to `false`.
    fn supports_framework(&self, _framework: ComplianceFramework) -> bool {
        false
    }

    /// Checks `_context` against `_framework`.
    ///
    /// The default ignores both arguments and returns `Ok(None)`.
    async fn check_compliance(
        &self,
        _framework: ComplianceFramework,
        _context: &str,
    ) -> Result<Option<ComplianceCheckResult>> {
        Ok(None)
    }

    /// Records an audit entry. The default discards the entry and returns `Ok(())`.
    async fn audit_log(&self, _entry: AuditEntry) -> Result<()> {
        Ok(())
    }

    /// Fetches audit entries for the `_start`/`_end` range, capped by `_limit`.
    ///
    /// The default returns an empty list.
    /// NOTE(review): whether the range bounds are inclusive is up to the
    /// implementor; nothing here defines it.
    async fn get_audit_logs(
        &self,
        _start: i64,
        _end: i64,
        _limit: usize,
    ) -> Result<Vec<AuditEntry>> {
        Ok(Vec::new())
    }

    /// Produces an audit report on behalf of `_auditor`. The default returns `Ok(None)`.
    async fn run_audit(&self, _auditor: &str) -> Result<Option<ComplianceAuditReport>> {
        Ok(None)
    }
}
/// Report returned by `ComplianceProvider::run_audit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComplianceAuditReport {
/// Identifier of the report.
pub id: UUID,
/// Auditor the report was produced for or by.
pub auditor: String,
/// Frameworks covered by the report.
pub frameworks: Vec<ComplianceFramework>,
/// Overall compliance figure.
/// NOTE(review): presumably 0–100, but the scale is not defined here — confirm.
pub compliance_percentage: f32,
/// Findings grouped under string keys.
/// NOTE(review): the meaning of the key (framework name? category?) is not defined here — confirm.
pub findings: HashMap<String, Vec<ComplianceFinding>>,
/// Time the report was generated. NOTE(review): unit not stated in this file.
pub generated_at: i64,
}
/// Description of one cluster node, as returned by `DistributedExecutor::get_nodes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
/// Node identifier.
pub id: String,
/// Address of the node; the format is not defined in this file.
pub address: String,
/// Current status of the node.
pub status: NodeStatus,
/// Resources associated with the node.
pub resources: NodeResources,
/// Time of the last heartbeat. NOTE(review): unit not stated in this file — confirm.
pub last_heartbeat: i64,
}
/// Status reported for a node in `NodeInfo::status`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum NodeStatus {
Healthy,
Overloaded,
Unhealthy,
Draining,
}
/// A bundle of compute resources.
///
/// Used for a node's resources (`NodeInfo`), a task's requirements
/// (`DistributedTask`) and a cluster total (`ClusterStatus`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeResources {
/// CPU cores; fractional values are representable.
pub cpu_cores: f32,
/// Memory in megabytes (per the field name).
pub memory_mb: u64,
/// Number of GPU devices.
pub gpu_devices: u32,
}
/// A unit of work submitted through `DistributedExecutor::submit_task`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTask {
/// Task identifier.
pub id: UUID,
/// Task kind; the set of valid values is executor-defined.
pub task_type: String,
/// Arbitrary JSON payload for the task.
pub payload: serde_json::Value,
/// Resources the task requires.
pub requirements: NodeResources,
/// Scheduling priority.
/// NOTE(review): whether larger means more urgent is not defined here — confirm.
pub priority: i32,
}
/// Result of a task, as returned by `DistributedExecutor::get_task_result` and
/// `DistributedExecutor::wait_for_task`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTaskResult {
/// Identifier of the task this result belongs to.
pub task_id: UUID,
/// Identifier of whatever executed the task.
/// NOTE(review): presumably a `NodeInfo::id` — confirm.
pub executed_by: String,
/// Whether the task succeeded.
pub success: bool,
/// JSON result value, if one was produced.
pub result: Option<serde_json::Value>,
/// Error message, if any.
pub error: Option<String>,
/// Execution time in milliseconds (per the field name).
pub duration_ms: u64,
}
/// Extension point for running tasks, possibly across a cluster.
///
/// Implementors must provide [`name`](Self::name),
/// [`submit_task`](Self::submit_task), [`get_task_result`](Self::get_task_result)
/// and [`wait_for_task`](Self::wait_for_task). The cluster-related methods
/// default to answers describing a non-distributed executor.
#[async_trait]
pub trait DistributedExecutor: Send + Sync {
    /// Name identifying this executor.
    fn name(&self) -> &str;

    /// Whether this executor actually distributes work. Defaults to `false`.
    fn is_distributed(&self) -> bool {
        false
    }

    /// Lists known nodes. The default returns an empty list.
    async fn get_nodes(&self) -> Result<Vec<NodeInfo>> {
        Ok(Vec::new())
    }

    /// Submits `task` and returns a task identifier.
    async fn submit_task(&self, task: DistributedTask) -> Result<UUID>;

    /// Looks up the result for `task_id`, if one is available.
    async fn get_task_result(&self, task_id: UUID) -> Result<Option<DistributedTaskResult>>;

    /// Waits for `task_id` with a timeout of `timeout_ms`.
    ///
    /// NOTE(review): presumably `Ok(None)` signals that the timeout elapsed
    /// without a result — confirm with the implementors.
    async fn wait_for_task(
        &self,
        task_id: UUID,
        timeout_ms: u64,
    ) -> Result<Option<DistributedTaskResult>>;

    /// Joins the cluster reachable via `_coordinator`. The default returns `Ok(false)`.
    async fn join_cluster(&self, _coordinator: &str) -> Result<bool> {
        Ok(false)
    }

    /// Leaves the current cluster. The default returns `Ok(false)`.
    async fn leave_cluster(&self) -> Result<bool> {
        Ok(false)
    }

    /// Reports cluster status. The default returns `Ok(None)`.
    async fn cluster_status(&self) -> Result<Option<ClusterStatus>> {
        Ok(None)
    }
}
/// Cluster summary returned by `DistributedExecutor::cluster_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStatus {
/// Cluster name.
pub name: String,
/// Total number of nodes.
pub node_count: usize,
/// Number of nodes considered healthy.
pub healthy_nodes: usize,
/// Resources summed over the cluster.
/// NOTE(review): whether unhealthy nodes are included is not defined here — confirm.
pub total_resources: NodeResources,
/// Number of tasks waiting to run.
pub queued_tasks: usize,
/// Number of tasks currently running.
pub running_tasks: usize,
}
/// A policy rule, as listed by `PolicyProvider::get_rules` and stored with
/// `PolicyProvider::set_rule`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyRule {
/// Rule identifier (the key accepted by `PolicyProvider::delete_rule`).
pub id: String,
/// Display name of the rule.
pub name: String,
/// Kind of rule.
pub rule_type: PolicyRuleType,
/// Rule-specific configuration as JSON; its schema is not defined in this file.
pub config: serde_json::Value,
/// Whether the rule is enabled.
pub enabled: bool,
}
/// Kind of a `PolicyRule`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PolicyRuleType {
RateLimit,
CostCap,
ContentFilter,
ToolPermission,
Custom,
}
/// Verdict returned by `PolicyProvider::check_policy`.
///
/// The trait's default implementation returns `allowed == true` with the
/// reason "Policy engine not enabled" and no matched rules.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PolicyDecision {
/// Whether the action is permitted.
pub allowed: bool,
/// Human-readable explanation of the decision.
pub reason: String,
/// Rules that matched.
/// NOTE(review): presumably `PolicyRule::id` values — confirm.
pub matched_rules: Vec<String>,
}
/// Extension point for policy evaluation and rule management.
///
/// Only [`name`](Self::name) is required. All other methods default to the
/// behaviour of a disabled policy engine.
#[async_trait]
pub trait PolicyProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Whether policy enforcement is active. Defaults to `false`.
    fn is_enabled(&self) -> bool {
        false
    }

    /// Decides whether `_action` is permitted in `_context`.
    ///
    /// The default ignores both arguments and allows the action, stating that
    /// the policy engine is not enabled.
    async fn check_policy(
        &self,
        _action: &str,
        _context: &HashMap<String, serde_json::Value>,
    ) -> Result<PolicyDecision> {
        let permissive = PolicyDecision {
            allowed: true,
            reason: String::from("Policy engine not enabled"),
            matched_rules: Vec::new(),
        };
        Ok(permissive)
    }

    /// Lists configured rules. The default returns an empty list.
    async fn get_rules(&self) -> Result<Vec<PolicyRule>> {
        Ok(Vec::new())
    }

    /// Stores `_rule`. The default discards it and returns `Ok(false)`.
    async fn set_rule(&self, _rule: PolicyRule) -> Result<bool> {
        Ok(false)
    }

    /// Deletes the rule identified by `_rule_id`. The default returns `Ok(false)`.
    async fn delete_rule(&self, _rule_id: &str) -> Result<bool> {
        Ok(false)
    }
}
/// An identity record, as returned by `IdentityProvider::get_identity`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Identity {
/// Identifier of the identity.
pub id: UUID,
/// Handle associated with the identity.
pub handle: String,
/// Consent records attached to the identity.
pub consents: Vec<ConsentScope>,
/// Retention period in days (per the field name).
/// NOTE(review): the meaning of zero or negative values is not defined here — confirm.
pub retention_days: i32,
/// Creation time. NOTE(review): unit not stated in this file.
pub created_at: i64,
}
/// Consent state for one scope, stored in `Identity::consents`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConsentScope {
/// Name of the scope; the set of valid scopes is not defined in this file.
pub scope: String,
/// Whether consent is currently granted.
pub granted: bool,
/// Time the consent was granted, if recorded. NOTE(review): unit not stated in this file.
pub granted_at: Option<i64>,
}
/// A data-export request for an identity.
///
/// NOTE(review): no trait method in this file accepts or returns this type;
/// `IdentityProvider::request_export` only returns an optional `UUID`,
/// presumably this request's `id` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataExportRequest {
/// Identifier of the request.
pub id: UUID,
/// Identity whose data is being exported.
pub identity_id: UUID,
/// Requested export format; valid values are not defined in this file.
pub format: String,
/// Current processing status.
pub status: DataRequestStatus,
/// Where the export can be downloaded, once available.
pub download_url: Option<String>,
}
/// A data-deletion request for an identity.
///
/// NOTE(review): no trait method in this file accepts or returns this type;
/// `IdentityProvider::request_deletion` only returns an optional `UUID`,
/// presumably this request's `id` — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DataDeletionRequest {
/// Identifier of the request.
pub id: UUID,
/// Identity whose data is being deleted.
pub identity_id: UUID,
/// Current processing status.
pub status: DataRequestStatus,
/// Completion time, once finished. NOTE(review): unit not stated in this file.
pub completed_at: Option<i64>,
}
/// Processing status shared by `DataExportRequest` and `DataDeletionRequest`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataRequestStatus {
Pending,
InProgress,
Completed,
Failed,
}
/// Extension point for identity lookup, consent, export/deletion requests and retention.
///
/// Only [`name`](Self::name) is required. All other methods default to the
/// behaviour of a disabled identity service.
#[async_trait]
pub trait IdentityProvider: Send + Sync {
    /// Name identifying this provider.
    fn name(&self) -> &str;

    /// Whether identity features are active. Defaults to `false`.
    fn is_enabled(&self) -> bool {
        false
    }

    /// Looks up the identity with the given `_id`. The default returns `Ok(None)`.
    async fn get_identity(&self, _id: UUID) -> Result<Option<Identity>> {
        Ok(None)
    }

    /// Sets consent for `_scope` on `_identity_id` to `_granted`.
    ///
    /// The default changes nothing and returns `Ok(false)`.
    async fn update_consent(
        &self,
        _identity_id: UUID,
        _scope: &str,
        _granted: bool,
    ) -> Result<bool> {
        Ok(false)
    }

    /// Requests a data export for `_identity_id` in `_format`.
    ///
    /// The default returns `Ok(None)`.
    async fn request_export(&self, _identity_id: UUID, _format: &str) -> Result<Option<UUID>> {
        Ok(None)
    }

    /// Requests deletion of the data held for `_identity_id`.
    ///
    /// The default returns `Ok(None)`.
    async fn request_deletion(&self, _identity_id: UUID) -> Result<Option<UUID>> {
        Ok(None)
    }

    /// Sets the retention period for `_identity_id` to `_days`.
    ///
    /// The default changes nothing and returns `Ok(false)`.
    async fn set_retention(&self, _identity_id: UUID, _days: i32) -> Result<bool> {
        Ok(false)
    }
}
/// In-memory [`LearningProvider`] that keeps only the most recent feedback.
///
/// It does not override any training-related method, so the trait defaults
/// apply (`supports_training()` is `false`).
pub struct BasicLearningProvider {
    /// Stored feedback, oldest entry at the front. A deque is used so that
    /// evicting the oldest entry is O(1) instead of shifting every element.
    feedback: std::sync::RwLock<std::collections::VecDeque<LearningFeedback>>,
}

impl BasicLearningProvider {
    /// Maximum number of feedback entries retained; the oldest is evicted first.
    const MAX_FEEDBACK: usize = 1000;

    /// Creates a provider with an empty feedback buffer.
    pub fn new() -> Self {
        Self {
            feedback: std::sync::RwLock::new(std::collections::VecDeque::new()),
        }
    }
}

impl Default for BasicLearningProvider {
    /// Same as [`BasicLearningProvider::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl LearningProvider for BasicLearningProvider {
    fn name(&self) -> &str {
        "basic-learning"
    }

    /// Appends `feedback`, evicting the oldest entry once more than
    /// `MAX_FEEDBACK` entries are held.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    async fn collect_feedback(&self, feedback: LearningFeedback) -> Result<()> {
        let mut store = self.feedback.write().unwrap();
        store.push_back(feedback);
        if store.len() > BasicLearningProvider::MAX_FEEDBACK {
            store.pop_front();
        }
        Ok(())
    }

    /// Returns up to `limit` entries, newest first.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    async fn get_feedback(&self, limit: usize) -> Result<Vec<LearningFeedback>> {
        let store = self.feedback.read().unwrap();
        Ok(store.iter().rev().take(limit).cloned().collect())
    }
}
/// Regex-based `ComplianceProvider` with a fixed set of built-in PII detectors.
pub struct BasicComplianceProvider {
/// Detectors as `(category, regex)` pairs. Both `scan_pii` and `redact`
/// walk this list in order.
patterns: Vec<(PiiType, regex::Regex)>,
}
impl BasicComplianceProvider {
    /// Builds the provider with its built-in detectors for e-mail addresses,
    /// API keys, credit-card numbers, SSNs and phone numbers.
    ///
    /// The order of the list is significant. `redact` applies the patterns one
    /// after another to the already-redacted text, so broader patterns must
    /// run before narrower ones that can match a fragment of the same value:
    ///
    /// * `Email` runs first so that digits in the local part are not consumed
    ///   by the SSN or phone pattern, which would leave an unmatched address.
    /// * `ApiKey` runs before `CreditCard` so that a key ending in 16 digits is
    ///   masked as a whole.
    /// * `CreditCard` runs before `Phone`. With the phone pattern first,
    ///   `4111 1111 1111 1111` became `4111 [PHONE] 1111`, leaking eight digits.
    ///
    /// # Panics
    ///
    /// Panics only if one of the hard-coded patterns fails to compile, which
    /// would be a programming error.
    pub fn new() -> Self {
        let compile = |pattern: &str| {
            regex::Regex::new(pattern).expect("built-in PII pattern must compile")
        };
        let patterns = vec![
            (
                PiiType::Email,
                // `[A-Za-z]`, not `[A-Z|a-z]`: inside a character class `|` is a
                // literal pipe, not alternation.
                compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"),
            ),
            (
                PiiType::ApiKey,
                compile(r"\b(sk-[a-zA-Z0-9]{32,}|api[_-]?key[_-]?[a-zA-Z0-9]{16,})\b"),
            ),
            (
                PiiType::CreditCard,
                compile(r"\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b"),
            ),
            (
                PiiType::Ssn,
                compile(r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b"),
            ),
            (
                PiiType::Phone,
                compile(r"\b(\+?1[-.\s]?)?(\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"),
            ),
        ];
        Self { patterns }
    }
}
impl Default for BasicComplianceProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl ComplianceProvider for BasicComplianceProvider {
    fn name(&self) -> &str {
        "basic-compliance"
    }

    /// Runs every built-in detector over `text` and reports each match.
    ///
    /// Findings are grouped by detector, in the order the detectors are
    /// stored, and by position within each group. Matches from different
    /// detectors may overlap; no de-duplication is performed.
    async fn scan_pii(&self, text: &str) -> Result<Vec<PiiFinding>> {
        let mut findings = Vec::new();
        for (kind, detector) in self.patterns.iter() {
            // Severity depends only on the category, so it is resolved once
            // per detector rather than once per match.
            let severity = match kind {
                PiiType::Ssn | PiiType::CreditCard | PiiType::ApiKey => Severity::Critical,
                PiiType::Email | PiiType::Phone => Severity::Medium,
                _ => Severity::Low,
            };
            findings.extend(detector.find_iter(text).map(|hit| PiiFinding {
                pii_type: kind.clone(),
                start: hit.start(),
                end: hit.end(),
                matched_text: hit.as_str().to_owned(),
                severity,
                confidence: 0.9,
            }));
        }
        Ok(findings)
    }

    /// Returns `text` with every detector match replaced by a bracketed
    /// placeholder such as `[SSN]` or `[EMAIL]`.
    ///
    /// Detectors are applied sequentially, each one to the output of the
    /// previous one.
    fn redact(&self, text: &str) -> String {
        self.patterns
            .iter()
            .fold(text.to_owned(), |current, (kind, detector)| {
                let placeholder = match kind {
                    PiiType::Ssn => "[SSN]",
                    PiiType::Email => "[EMAIL]",
                    PiiType::Phone => "[PHONE]",
                    PiiType::CreditCard => "[CREDIT_CARD]",
                    PiiType::ApiKey => "[API_KEY]",
                    _ => "[REDACTED]",
                };
                detector.replace_all(&current, placeholder).into_owned()
            })
    }
}
/// Holder for the optional extension providers.
///
/// Each slot is `None` when no provider of that kind is registered.
/// `ExtensionRegistry::new` pre-populates `learning` and `compliance` with the
/// basic providers; `ExtensionRegistry::empty` leaves every slot unset.
pub struct ExtensionRegistry {
/// Feedback / training provider.
pub learning: Option<Arc<dyn LearningProvider>>,
/// PII / compliance provider.
pub compliance: Option<Arc<dyn ComplianceProvider>>,
/// Distributed task executor.
pub distributed: Option<Arc<dyn DistributedExecutor>>,
/// Policy provider.
pub policy: Option<Arc<dyn PolicyProvider>>,
/// Identity provider.
pub identity: Option<Arc<dyn IdentityProvider>>,
}
impl ExtensionRegistry {
pub fn new() -> Self {
Self {
learning: Some(Arc::new(BasicLearningProvider::new())),
compliance: Some(Arc::new(BasicComplianceProvider::new())),
distributed: None,
policy: None,
identity: None,
}
}
pub fn empty() -> Self {
Self {
learning: None,
compliance: None,
distributed: None,
policy: None,
identity: None,
}
}
pub fn with_learning(mut self, provider: Arc<dyn LearningProvider>) -> Self {
self.learning = Some(provider);
self
}
pub fn with_compliance(mut self, provider: Arc<dyn ComplianceProvider>) -> Self {
self.compliance = Some(provider);
self
}
pub fn with_distributed(mut self, executor: Arc<dyn DistributedExecutor>) -> Self {
self.distributed = Some(executor);
self
}
pub fn with_policy(mut self, provider: Arc<dyn PolicyProvider>) -> Self {
self.policy = Some(provider);
self
}
pub fn with_identity(mut self, provider: Arc<dyn IdentityProvider>) -> Self {
self.identity = Some(provider);
self
}
pub fn has_enterprise_features(&self) -> bool {
self.learning
.as_ref()
.map(|l| l.supports_training())
.unwrap_or(false)
|| self
.compliance
.as_ref()
.map(|c| c.supports_framework(ComplianceFramework::Hipaa))
.unwrap_or(false)
|| self
.distributed
.as_ref()
.map(|d| d.is_distributed())
.unwrap_or(false)
|| self.policy.as_ref().map(|p| p.is_enabled()).unwrap_or(false)
|| self
.identity
.as_ref()
.map(|i| i.is_enabled())
.unwrap_or(false)
}
}
impl Default for ExtensionRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A collected entry can be read back unchanged.
    #[tokio::test]
    async fn test_basic_learning_provider() {
        let learner = BasicLearningProvider::new();
        let entry = LearningFeedback {
            response_id: uuid::Uuid::new_v4(),
            score: 0.8,
            text: Some("Good response".to_string()),
            source: FeedbackSource::User,
            timestamp: chrono::Utc::now().timestamp(),
        };

        learner.collect_feedback(entry).await.unwrap();

        let stored = learner.get_feedback(10).await.unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].score, 0.8);
    }

    /// SSNs and e-mail addresses are both detected and redacted.
    #[tokio::test]
    async fn test_basic_compliance_provider() {
        let scanner = BasicComplianceProvider::new();
        let sample = "My SSN is 123-45-6789 and email is test@example.com";

        let hits = scanner.scan_pii(sample).await.unwrap();
        let found = |wanted: PiiType| hits.iter().any(|hit| hit.pii_type == wanted);
        assert!(found(PiiType::Ssn));
        assert!(found(PiiType::Email));

        let masked = scanner.redact(sample);
        assert!(masked.contains("[SSN]"));
        assert!(masked.contains("[EMAIL]"));
        assert!(!masked.contains("123-45-6789"));
    }

    /// The basic provider falls back to the trait's "training unavailable" defaults.
    #[tokio::test]
    async fn test_training_not_available_in_consumer() {
        let learner = BasicLearningProvider::new();
        assert!(!learner.supports_training());

        let outcome = learner.train().await.unwrap();
        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    /// The default registry ships the basic providers and no enterprise features.
    #[test]
    fn test_extension_registry() {
        let defaults = ExtensionRegistry::new();
        assert!(defaults.learning.is_some());
        assert!(defaults.compliance.is_some());
        assert!(!defaults.has_enterprise_features());
    }
}