Skip to main content

ternary_federated/
lib.rs

1//! Federated learning for ternary agents.
2//!
3//! Multiple populations share strategy insights without sharing raw data.
4//! Each node runs local ternary evolution, and a federated round aggregates
5//! strategy summaries across nodes using configurable aggregation methods,
6//! all while tracking a differential-privacy-style privacy budget.
7
8use std::fmt;
9
10// ---------------------------------------------------------------------------
11// Ternary value
12// ---------------------------------------------------------------------------
13
/// One trit: negative one, zero, or positive one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Ternary {
    Neg = -1,
    Zero = 0,
    Pos = 1,
}

impl Ternary {
    /// Map any `i8` onto a trit by its sign: negatives become `Neg`, zero stays
    /// `Zero`, and positives become `Pos`.
    pub fn from_i8(v: i8) -> Self {
        match v.cmp(&0) {
            std::cmp::Ordering::Less => Ternary::Neg,
            std::cmp::Ordering::Equal => Ternary::Zero,
            std::cmp::Ordering::Greater => Ternary::Pos,
        }
    }

    /// The trit as an `i8` in `{-1, 0, 1}`.
    pub fn as_i8(self) -> i8 {
        match self {
            Ternary::Neg => -1,
            Ternary::Zero => 0,
            Ternary::Pos => 1,
        }
    }

    /// The trit as an `f64` in `{-1.0, 0.0, 1.0}`.
    pub fn as_f64(self) -> f64 {
        f64::from(self.as_i8())
    }

    /// Draw a pseudo-random trit.
    ///
    /// Advances `state` by exactly one step of a 64-bit linear congruential generator,
    /// then maps the high bits of the new state onto `Neg`/`Zero`/`Pos`. The same
    /// starting `state` always yields the same trit.
    pub fn random(state: &mut u64) -> Self {
        const TRITS: [Ternary; 3] = [Ternary::Neg, Ternary::Zero, Ternary::Pos];
        let next = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        *state = next;
        TRITS[((next >> 33) % 3) as usize]
    }
}

impl fmt::Display for Ternary {
    /// Prints the numeric value (`-1`, `0` or `1`); width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let n = self.as_i8();
        write!(f, "{}", n)
    }
}
59
60// ---------------------------------------------------------------------------
61// Strategy
62// ---------------------------------------------------------------------------
63
64/// A strategy is a vector of ternary values.
65#[derive(Clone, PartialEq, Eq, Debug)]
66pub struct Strategy {
67    pub values: Vec<Ternary>,
68}
69
70impl Strategy {
71    /// Create a random strategy of given length.
72    pub fn random(len: usize, state: &mut u64) -> Self {
73        let values = (0..len).map(|_| Ternary::random(state)).collect();
74        Self { values }
75    }
76
77    /// Create a strategy of all zeros.
78    pub fn zeros(len: usize) -> Self {
79        Self {
80            values: vec![Ternary::Zero; len],
81        }
82    }
83
84    /// Length of the strategy.
85    pub fn len(&self) -> usize {
86        self.values.len()
87    }
88
89    /// Whether the strategy is empty.
90    pub fn is_empty(&self) -> bool {
91        self.values.is_empty()
92    }
93
94    /// Compute fitness against a target (number of matching positions).
95    pub fn fitness_against(&self, target: &Strategy) -> f64 {
96        if self.len() != target.len() {
97            return 0.0;
98        }
99        let matches = self
100            .values
101            .iter()
102            .zip(&target.values)
103            .filter(|(a, b)| a == b)
104            .count();
105        matches as f64 / self.len() as f64
106    }
107
108    /// Mutate one random position in the strategy.
109    pub fn mutate(&mut self, state: &mut u64) {
110        if self.is_empty() {
111            return;
112        }
113        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
114        let idx = (*state as usize) % self.len();
115        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
116        let v = (*state >> 33) % 3;
117        self.values[idx] = match v {
118            0 => Ternary::Neg,
119            1 => Ternary::Zero,
120            _ => Ternary::Pos,
121        };
122    }
123}
124
125// ---------------------------------------------------------------------------
126// Agent
127// ---------------------------------------------------------------------------
128
129/// A single ternary agent with a strategy and fitness.
130#[derive(Clone, Debug)]
131pub struct Agent {
132    pub strategy: Strategy,
133    pub fitness: f64,
134}
135
136impl Agent {
137    /// Create a random agent with strategy of given length.
138    pub fn random(strategy_len: usize, state: &mut u64) -> Self {
139        Self {
140            strategy: Strategy::random(strategy_len, state),
141            fitness: 0.0,
142        }
143    }
144
145    /// Evaluate fitness against a target strategy.
146    pub fn evaluate(&mut self, target: &Strategy) {
147        self.fitness = self.strategy.fitness_against(target);
148    }
149}
150
151// ---------------------------------------------------------------------------
152// Node
153// ---------------------------------------------------------------------------
154
155/// A local ternary population with its own evolution.
156#[derive(Clone, Debug)]
157pub struct Node {
158    /// The agents in this node's population.
159    pub agents: Vec<Agent>,
160    /// This node's secret target strategy (never shared).
161    pub target: Strategy,
162    /// PRNG state for deterministic evolution.
163    pub rng_state: u64,
164    /// Unique node identifier.
165    pub id: usize,
166    /// Best fitness achieved so far.
167    pub best_fitness: f64,
168    /// Strategy that achieved best fitness.
169    pub best_strategy: Strategy,
170}
171
172impl Node {
173    /// Create a new node with `population_size` agents and strategies of `strategy_len`.
174    pub fn new(population_size: usize, strategy_len: usize) -> Self {
175        Self::with_id(population_size, strategy_len, 0)
176    }
177
178    /// Create a node with a specific ID.
179    pub fn with_id(population_size: usize, strategy_len: usize, id: usize) -> Self {
180        let mut state = id as u64 * 1_000_003 + 42;
181        let target = Strategy::random(strategy_len, &mut state);
182        let agents: Vec<Agent> = (0..population_size)
183            .map(|_| Agent::random(strategy_len, &mut state))
184            .collect();
185        let best_strategy = Strategy::zeros(strategy_len);
186        Self {
187            agents,
188            target,
189            rng_state: state,
190            id,
191            best_fitness: 0.0,
192            best_strategy,
193        }
194    }
195
196    /// Set a specific target for the node.
197    pub fn with_target(mut self, target: Strategy) -> Self {
198        self.target = target;
199        self
200    }
201
202    /// Run one generation of local evolution.
203    pub fn evolve_step(&mut self) {
204        // Evaluate all agents
205        for agent in &mut self.agents {
206            agent.evaluate(&self.target);
207        }
208
209        // Sort by fitness descending
210        self.agents.sort_by(|a, b| b.fitness.partial_cmp(&a.fitness).unwrap_or(std::cmp::Ordering::Equal));
211
212        // Track best
213        if let Some(best) = self.agents.first() {
214            if best.fitness > self.best_fitness {
215                self.best_fitness = best.fitness;
216                self.best_strategy = best.strategy.clone();
217            }
218        }
219
220        // Tournament selection + mutation: replace bottom half with mutated copies of top half
221        let pop = self.agents.len();
222        if pop < 2 {
223            return;
224        }
225        let half = pop / 2;
226        let top: Vec<Agent> = self.agents[..half].to_vec();
227        for i in half..pop {
228            let mut child = top[i - half].clone();
229            child.strategy.mutate(&mut self.rng_state);
230            self.agents[i] = child;
231        }
232    }
233
234    /// Run N generations of local evolution.
235    pub fn evolve(&mut self, generations: usize) {
236        for _ in 0..generations {
237            self.evolve_step();
238        }
239    }
240
241    /// Get the average fitness across all agents.
242    pub fn avg_fitness(&self) -> f64 {
243        if self.agents.is_empty() {
244            return 0.0;
245        }
246        self.agents.iter().map(|a| a.fitness).sum::<f64>() / self.agents.len() as f64
247    }
248
249    /// Get the best fitness.
250    pub fn max_fitness(&self) -> f64 {
251        self.agents
252            .iter()
253            .map(|a| a.fitness)
254            .fold(0.0_f64, f64::max)
255    }
256
257    /// Get the strategy summary to share (the best strategy).
258    /// In a real system this would be perturbed for privacy.
259    pub fn strategy_summary(&self) -> Strategy {
260        self.best_strategy.clone()
261    }
262
263    /// Apply a federated strategy update: blend the global strategy into the population.
264    pub fn apply_federated_update(&mut self, global: &Strategy, blend_rate: f64) {
265        for agent in &mut self.agents {
266            for (i, v) in agent.strategy.values.iter_mut().enumerate() {
267                if i < global.len() {
268                    // With probability blend_rate, adopt the global value
269                    self.rng_state = self.rng_state.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
270                    let prob = (self.rng_state >> 33) as f64 / (u32::MAX as f64);
271                    if prob < blend_rate {
272                        *v = global.values[i];
273                    }
274                }
275            }
276        }
277    }
278}
279
280// ---------------------------------------------------------------------------
281// Aggregation
282// ---------------------------------------------------------------------------
283
284/// Method for aggregating strategies across nodes.
285#[derive(Clone, Copy, Debug, PartialEq)]
286pub enum AggregationMethod {
287    /// Each position takes the value most common across nodes.
288    MajorityVote,
289    /// Nodes contribute proportionally to their fitness scores.
290    WeightedAverage,
291    /// Adopt the strategy from the highest-fitness node.
292    BestOf,
293}
294
295/// Aggregates strategy summaries from multiple nodes.
296pub struct Aggregator;
297
298impl Aggregator {
299    /// Aggregate strategies using the specified method.
300    pub fn aggregate(
301        summaries: &[(Strategy, f64)], // (strategy, fitness) per node
302        method: AggregationMethod,
303    ) -> Strategy {
304        if summaries.is_empty() {
305            return Strategy::zeros(0);
306        }
307        let len = summaries[0].0.len();
308        match method {
309            AggregationMethod::MajorityVote => Self::majority_vote(summaries, len),
310            AggregationMethod::WeightedAverage => Self::weighted_average(summaries, len),
311            AggregationMethod::BestOf => Self::best_of(summaries),
312        }
313    }
314
315    fn majority_vote(summaries: &[(Strategy, f64)], len: usize) -> Strategy {
316        let mut result = Vec::with_capacity(len);
317        for i in 0..len {
318            let mut counts = [0usize; 3]; // neg, zero, pos
319            for (s, _) in summaries {
320                if i < s.len() {
321                    match s.values[i] {
322                        Ternary::Neg => counts[0] += 1,
323                        Ternary::Zero => counts[1] += 1,
324                        Ternary::Pos => counts[2] += 1,
325                    }
326                }
327            }
328            let best = counts.iter().enumerate().max_by_key(|(_, &c)| c).map(|(i, _)| i).unwrap_or(1);
329            result.push(match best {
330                0 => Ternary::Neg,
331                2 => Ternary::Pos,
332                _ => Ternary::Zero,
333            });
334        }
335        Strategy { values: result }
336    }
337
338    fn weighted_average(summaries: &[(Strategy, f64)], len: usize) -> Strategy {
339        let total_weight: f64 = summaries.iter().map(|(_, f)| f).sum();
340        if total_weight <= 0.0 {
341            return Strategy::zeros(len);
342        }
343        let mut result = Vec::with_capacity(len);
344        for i in 0..len {
345            let mut weighted_sum = 0.0;
346            for (s, f) in summaries {
347                if i < s.len() {
348                    weighted_sum += s.values[i].as_f64() * f;
349                }
350            }
351            let avg = weighted_sum / total_weight;
352            result.push(if avg > 0.33 {
353                Ternary::Pos
354            } else if avg < -0.33 {
355                Ternary::Neg
356            } else {
357                Ternary::Zero
358            });
359        }
360        Strategy { values: result }
361    }
362
363    fn best_of(summaries: &[(Strategy, f64)]) -> Strategy {
364        summaries
365            .iter()
366            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
367            .map(|(s, _)| s.clone())
368            .unwrap_or_else(|| Strategy::zeros(0))
369    }
370}
371
372// ---------------------------------------------------------------------------
373// Privacy Budget
374// ---------------------------------------------------------------------------
375
/// Tracks how much privacy budget (epsilon) has been consumed, in the style of
/// differential-privacy accounting. This type only does bookkeeping; it adds no noise.
#[derive(Clone, Debug)]
pub struct PrivacyBudget {
    /// Total epsilon available.
    pub total_epsilon: f64,
    /// Epsilon spent so far.
    pub spent: f64,
}

impl PrivacyBudget {
    /// Create a new privacy budget with the given total epsilon and nothing spent.
    pub fn new(total_epsilon: f64) -> Self {
        Self {
            total_epsilon,
            spent: 0.0,
        }
    }

    /// Whether any budget remains (`spent < total_epsilon`).
    pub fn has_budget(&self) -> bool {
        self.spent < self.total_epsilon
    }

    /// Remaining epsilon, never negative.
    pub fn remaining(&self) -> f64 {
        (self.total_epsilon - self.spent).max(0.0)
    }

    /// Try to spend `epsilon`.
    ///
    /// Returns `true` and records the spend only if `epsilon` is a non-negative number
    /// that fits in the remaining budget. Otherwise it returns `false` and leaves the
    /// budget unchanged.
    pub fn spend(&mut self, epsilon: f64) -> bool {
        // A negative epsilon would refund budget and a NaN would poison `spent`.
        if epsilon.is_nan() || epsilon < 0.0 {
            return false;
        }
        if self.spent + epsilon > self.total_epsilon {
            return false;
        }
        self.spent += epsilon;
        true
    }

    /// Fraction of the budget spent; a non-positive total counts as fully spent (1.0).
    pub fn fraction_spent(&self) -> f64 {
        if self.total_epsilon <= 0.0 {
            1.0
        } else {
            self.spent / self.total_epsilon
        }
    }

    /// Reset the spent amount to zero, keeping the total.
    pub fn reset(&mut self) {
        self.spent = 0.0;
    }
}
427
428// ---------------------------------------------------------------------------
429// Federated Round
430// ---------------------------------------------------------------------------
431
432/// One round of federated aggregation.
433#[derive(Clone, Debug)]
434pub struct FederatedRound {
435    /// Round number (0-indexed).
436    pub round_number: usize,
437    /// Strategy summaries from each node (with fitness).
438    pub summaries: Vec<(Strategy, f64)>,
439    /// Aggregated global strategy.
440    pub global_strategy: Strategy,
441    /// Privacy epsilon spent this round.
442    pub epsilon_spent: f64,
443    /// Per-node fitness before aggregation.
444    pub pre_aggregation_fitness: Vec<f64>,
445    /// Per-node fitness after aggregation.
446    pub post_aggregation_fitness: Vec<f64>,
447}
448
449impl FederatedRound {
450    /// Execute one federated round.
451    pub fn execute(
452        nodes: &mut [Node],
453        round_number: usize,
454        method: AggregationMethod,
455        epsilon: f64,
456        privacy: &mut PrivacyBudget,
457    ) -> Option<Self> {
458        // Check privacy budget
459        if !privacy.has_budget() || privacy.remaining() < epsilon {
460            return None;
461        }
462
463        // Collect summaries
464        let pre_fitness: Vec<f64> = nodes.iter().map(|n| n.max_fitness()).collect();
465        let summaries: Vec<(Strategy, f64)> = nodes
466            .iter()
467            .map(|n| (n.strategy_summary(), n.best_fitness))
468            .collect();
469
470        // Aggregate
471        let global = Aggregator::aggregate(&summaries, method);
472
473        // Spend privacy budget
474        privacy.spend(epsilon);
475
476        // Apply update to all nodes (blend rate 0.3)
477        for node in nodes.iter_mut() {
478            node.apply_federated_update(&global, 0.3);
479        }
480
481        // Re-evaluate to get post-aggregation fitness
482        let post_fitness: Vec<f64> = {
483            nodes.iter_mut().for_each(|n| n.evolve(1));
484            nodes.iter().map(|n| n.max_fitness()).collect()
485        };
486
487        Some(Self {
488            round_number,
489            summaries,
490            global_strategy: global,
491            epsilon_spent: epsilon,
492            pre_aggregation_fitness: pre_fitness,
493            post_aggregation_fitness: post_fitness,
494        })
495    }
496}
497
498// ---------------------------------------------------------------------------
499// Federated Experiment Config
500// ---------------------------------------------------------------------------
501
502/// Configuration for a federated experiment.
503#[derive(Clone, Debug)]
504pub struct FederatedConfig {
505    /// Number of federated rounds.
506    pub rounds: usize,
507    /// Number of local evolution generations per round.
508    pub local_generations: usize,
509    /// Aggregation method to use.
510    pub aggregator: AggregationMethod,
511    /// Epsilon spent per round.
512    pub epsilon_per_round: f64,
513    /// Total privacy budget.
514    pub total_epsilon: f64,
515}
516
517impl Default for FederatedConfig {
518    fn default() -> Self {
519        Self {
520            rounds: 20,
521            local_generations: 10,
522            aggregator: AggregationMethod::WeightedAverage,
523            epsilon_per_round: 0.1,
524            total_epsilon: 5.0,
525        }
526    }
527}
528
529// ---------------------------------------------------------------------------
530// Federation Result
531// ---------------------------------------------------------------------------
532
533/// Structured result from a federated experiment.
534#[derive(Clone, Debug)]
535pub struct FederationResult {
536    /// Fitness history per node, per round.
537    pub per_node_fitness: Vec<Vec<f64>>,
538    /// Global strategy fitness at each round (average across nodes).
539    pub global_fitness_history: Vec<f64>,
540    /// Total privacy epsilon spent.
541    pub privacy_spent: f64,
542    /// Number of rounds completed.
543    pub rounds_completed: usize,
544    /// Number of nodes.
545    pub num_nodes: usize,
546    /// Per-round details.
547    pub rounds: Vec<FederatedRound>,
548    /// Final best strategy.
549    pub final_global_strategy: Strategy,
550    /// Whether the experiment completed all rounds.
551    pub completed: bool,
552}
553
554impl FederationResult {
555    /// Final global fitness (average of per-node fitness at last round).
556    pub fn global_fitness(&self) -> f64 {
557        self.global_fitness_history.last().copied().unwrap_or(0.0)
558    }
559
560    /// Total privacy spent.
561    pub fn privacy_spent(&self) -> f64 {
562        self.privacy_spent
563    }
564
565    /// Whether the experiment converged (fitness > 0.9 in last round).
566    pub fn converged(&self) -> bool {
567        self.global_fitness() > 0.9
568    }
569
570    /// Best-performing node index.
571    pub fn best_node(&self) -> usize {
572        self.per_node_fitness
573            .iter()
574            .enumerate()
575            .map(|(i, history)| (i, history.last().copied().unwrap_or(0.0)))
576            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
577            .map(|(i, _)| i)
578            .unwrap_or(0)
579    }
580
581    /// Summary string.
582    pub fn summary(&self) -> String {
583        format!(
584            "FederationResult: {} rounds, {} nodes, global_fitness={:.3}, privacy_spent={:.2}ε, converged={}",
585            self.rounds_completed,
586            self.num_nodes,
587            self.global_fitness(),
588            self.privacy_spent,
589            self.converged(),
590        )
591    }
592}
593
594// ---------------------------------------------------------------------------
595// Federated Experiment
596// ---------------------------------------------------------------------------
597
598/// Orchestrates a federated learning experiment.
599pub struct FederatedExperiment;
600
601impl FederatedExperiment {
602    /// Run a federated experiment with the given nodes and config.
603    pub fn run(mut nodes: Vec<Node>, config: FederatedConfig) -> FederationResult {
604        let num_nodes = nodes.len();
605        let mut privacy = PrivacyBudget::new(config.total_epsilon);
606        let mut per_node_fitness: Vec<Vec<f64>> = vec![vec![]; num_nodes];
607        let mut global_fitness_history = Vec::new();
608        let mut rounds_completed = 0;
609        let mut round_records = Vec::new();
610        let mut final_strategy = Strategy::zeros(0);
611
612        for round in 0..config.rounds {
613            // Local evolution
614            for node in nodes.iter_mut() {
615                node.evolve(config.local_generations);
616            }
617
618            // Federated round
619            let fed_round = FederatedRound::execute(
620                &mut nodes,
621                round,
622                config.aggregator,
623                config.epsilon_per_round,
624                &mut privacy,
625            );
626
627            match fed_round {
628                Some(fr) => {
629                    final_strategy = fr.global_strategy.clone();
630                    // Record fitness
631                    for (i, node) in nodes.iter().enumerate() {
632                        per_node_fitness[i].push(node.max_fitness());
633                    }
634                    let avg = nodes.iter().map(|n| n.max_fitness()).sum::<f64>() / num_nodes as f64;
635                    global_fitness_history.push(avg);
636                    rounds_completed += 1;
637                    round_records.push(fr);
638                }
639                None => {
640                    // Privacy budget exhausted — record final fitness without aggregation
641                    for (i, node) in nodes.iter().enumerate() {
642                        per_node_fitness[i].push(node.max_fitness());
643                    }
644                    let avg = nodes.iter().map(|n| n.max_fitness()).sum::<f64>() / num_nodes as f64;
645                    global_fitness_history.push(avg);
646                    rounds_completed += 1;
647                    break;
648                }
649            }
650        }
651
652        FederationResult {
653            per_node_fitness,
654            global_fitness_history,
655            privacy_spent: privacy.spent,
656            rounds_completed,
657            num_nodes,
658            rounds: round_records,
659            final_global_strategy: final_strategy,
660            completed: rounds_completed == config.rounds,
661        }
662    }
663}
664
665// ---------------------------------------------------------------------------
666// Tests
667// ---------------------------------------------------------------------------
668
// Unit and end-to-end tests for the federated ternary learning pipeline.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ternary_from_i8() {
        assert_eq!(Ternary::from_i8(-5), Ternary::Neg);
        assert_eq!(Ternary::from_i8(-1), Ternary::Neg);
        assert_eq!(Ternary::from_i8(0), Ternary::Zero);
        assert_eq!(Ternary::from_i8(1), Ternary::Pos);
        assert_eq!(Ternary::from_i8(42), Ternary::Pos);
    }

    #[test]
    fn test_ternary_conversions() {
        assert_eq!(Ternary::Neg.as_i8(), -1);
        assert_eq!(Ternary::Zero.as_i8(), 0);
        assert_eq!(Ternary::Pos.as_i8(), 1);
        assert_eq!(Ternary::Neg.as_f64(), -1.0);
        assert_eq!(Ternary::Zero.as_f64(), 0.0);
        assert_eq!(Ternary::Pos.as_f64(), 1.0);
    }

    #[test]
    fn test_ternary_display() {
        assert_eq!(format!("{}", Ternary::Neg), "-1");
        assert_eq!(format!("{}", Ternary::Zero), "0");
        assert_eq!(format!("{}", Ternary::Pos), "1");
    }

    #[test]
    fn test_ternary_random() {
        let mut state = 12345u64;
        for _ in 0..100 {
            let v = Ternary::random(&mut state);
            assert!(v == Ternary::Neg || v == Ternary::Zero || v == Ternary::Pos);
        }
    }

    #[test]
    fn test_strategy_random_length() {
        let mut state = 42u64;
        let s = Strategy::random(10, &mut state);
        assert_eq!(s.len(), 10);
        assert!(!s.is_empty());
    }

    #[test]
    fn test_strategy_zeros() {
        let s = Strategy::zeros(5);
        assert_eq!(s.len(), 5);
        assert!(s.values.iter().all(|v| *v == Ternary::Zero));
    }

    #[test]
    fn test_strategy_fitness_perfect() {
        let target = Strategy {
            values: vec![Ternary::Pos, Ternary::Neg, Ternary::Zero],
        };
        let s = target.clone();
        assert!((s.fitness_against(&target) - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_strategy_fitness_zero() {
        let target = Strategy {
            values: vec![Ternary::Pos, Ternary::Pos, Ternary::Pos],
        };
        let s = Strategy {
            values: vec![Ternary::Neg, Ternary::Neg, Ternary::Neg],
        };
        assert!((s.fitness_against(&target) - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_strategy_fitness_partial() {
        let target = Strategy {
            values: vec![Ternary::Pos, Ternary::Neg, Ternary::Zero, Ternary::Pos],
        };
        let s = Strategy {
            values: vec![Ternary::Pos, Ternary::Zero, Ternary::Zero, Ternary::Neg],
        };
        assert!((s.fitness_against(&target) - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_strategy_mutate_changes_something() {
        let mut state = 99u64;
        let original = Strategy {
            values: vec![Ternary::Zero; 20],
        };
        let mut mutated = original.clone();
        // A single mutation may redraw `Zero`; several make a change near-certain.
        for _ in 0..5 {
            mutated.mutate(&mut state);
        }
        assert_ne!(original, mutated);
    }

    #[test]
    fn test_node_creation() {
        let node = Node::new(50, 10);
        assert_eq!(node.agents.len(), 50);
        assert_eq!(node.target.len(), 10);
        assert_eq!(node.id, 0);
    }

    #[test]
    fn test_node_evolution_improves_fitness() {
        let mut node = Node::new(100, 20);
        // Run many generations
        node.evolve(100);
        // Should have improved from the initial 0.0
        assert!(node.best_fitness > 0.0);
    }

    #[test]
    fn test_node_avg_fitness() {
        let mut node = Node::new(50, 10);
        node.evolve(5);
        let avg = node.avg_fitness();
        assert!((0.0..=1.0).contains(&avg));
    }

    #[test]
    fn test_aggregator_majority_vote() {
        let s1 = Strategy {
            values: vec![Ternary::Pos, Ternary::Neg],
        };
        let s2 = Strategy {
            values: vec![Ternary::Pos, Ternary::Zero],
        };
        let s3 = Strategy {
            values: vec![Ternary::Neg, Ternary::Neg],
        };
        let result = Aggregator::aggregate(
            &[(s1, 0.5), (s2, 0.5), (s3, 0.5)],
            AggregationMethod::MajorityVote,
        );
        assert_eq!(result.values[0], Ternary::Pos); // 2 Pos vs 1 Neg
        assert_eq!(result.values[1], Ternary::Neg); // 2 Neg vs 1 Zero
    }

    #[test]
    fn test_aggregator_weighted_average() {
        let s1 = Strategy {
            values: vec![Ternary::Pos],
        };
        let s2 = Strategy {
            values: vec![Ternary::Neg],
        };
        // s1 has much higher weight
        let result = Aggregator::aggregate(
            &[(s1, 10.0), (s2, 1.0)],
            AggregationMethod::WeightedAverage,
        );
        assert_eq!(result.values[0], Ternary::Pos);
    }

    #[test]
    fn test_aggregator_best_of() {
        let s1 = Strategy {
            values: vec![Ternary::Pos],
        };
        let s2 = Strategy {
            values: vec![Ternary::Neg],
        };
        let result = Aggregator::aggregate(
            &[(s1.clone(), 0.3), (s2.clone(), 0.9)],
            AggregationMethod::BestOf,
        );
        assert_eq!(result.values[0], Ternary::Neg); // s2 had higher fitness
    }

    #[test]
    fn test_aggregator_empty() {
        let result = Aggregator::aggregate(&[], AggregationMethod::MajorityVote);
        assert!(result.is_empty());
    }

    #[test]
    fn test_privacy_budget_basic() {
        let mut pb = PrivacyBudget::new(1.0);
        assert!(pb.has_budget());
        assert!((pb.remaining() - 1.0).abs() < 1e-9);
        assert!(pb.spend(0.5));
        assert!((pb.remaining() - 0.5).abs() < 1e-9);
        assert!(pb.has_budget());
        assert!(pb.spend(0.5));
        assert!(!pb.has_budget());
        assert!(!pb.spend(0.1)); // No budget left
    }

    #[test]
    fn test_privacy_budget_fraction() {
        let mut pb = PrivacyBudget::new(2.0);
        pb.spend(0.5);
        assert!((pb.fraction_spent() - 0.25).abs() < 1e-9);
        pb.reset();
        assert!((pb.fraction_spent() - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_federated_round_execution() {
        let mut nodes = vec![
            Node::with_id(30, 8, 0),
            Node::with_id(30, 8, 1),
            Node::with_id(30, 8, 2),
        ];
        let mut privacy = PrivacyBudget::new(10.0);
        // Pre-evolve so agents have some fitness
        for node in nodes.iter_mut() {
            node.evolve(10);
        }
        let round = FederatedRound::execute(&mut nodes, 0, AggregationMethod::MajorityVote, 0.5, &mut privacy);
        assert!(round.is_some());
        let r = round.unwrap();
        assert_eq!(r.round_number, 0);
        assert!((r.epsilon_spent - 0.5).abs() < 1e-9);
        assert_eq!(r.summaries.len(), 3);
    }

    #[test]
    fn test_federated_round_privacy_exhausted() {
        let mut nodes = vec![Node::with_id(20, 5, 0)];
        let mut privacy = PrivacyBudget::new(0.1);
        let round = FederatedRound::execute(&mut nodes, 0, AggregationMethod::BestOf, 0.5, &mut privacy);
        // 0.1 remaining < 0.5 needed
        assert!(round.is_none());
    }

    #[test]
    fn test_full_experiment() {
        let nodes: Vec<Node> = (0..3).map(|i| Node::with_id(40, 8, i)).collect();
        let config = FederatedConfig {
            rounds: 10,
            local_generations: 5,
            aggregator: AggregationMethod::WeightedAverage,
            epsilon_per_round: 0.2,
            total_epsilon: 10.0,
        };
        let result = FederatedExperiment::run(nodes, config);
        assert_eq!(result.num_nodes, 3);
        assert!(result.rounds_completed > 0);
        assert!(!result.global_fitness_history.is_empty());
        assert!(result.privacy_spent > 0.0);
        // The budget (10.0) covers all 10 rounds, so every counted round was aggregated.
        assert_eq!(result.rounds.len(), result.rounds_completed);
        println!("{}", result.summary());
    }

    #[test]
    fn test_experiment_privacy_limited() {
        // Very tight budget — should stop early
        let nodes: Vec<Node> = (0..2).map(|i| Node::with_id(20, 5, i)).collect();
        let config = FederatedConfig {
            rounds: 100,
            local_generations: 3,
            aggregator: AggregationMethod::MajorityVote,
            epsilon_per_round: 1.0,
            total_epsilon: 3.0,
        };
        let result = FederatedExperiment::run(nodes, config);
        // At most 3 aggregated rounds fit in the 3.0 budget, plus the refused round that
        // ends the run (it is still counted in `rounds_completed`).
        assert!(result.rounds_completed <= 4);
        assert!(!result.completed);
    }

    #[test]
    fn test_convergence_with_shared_target() {
        // All nodes share the same target — should converge well
        let target = Strategy {
            values: vec![Ternary::Pos; 10],
        };
        let nodes: Vec<Node> = (0..5)
            .map(|i| Node::with_id(50, 10, i).with_target(target.clone()))
            .collect();
        let config = FederatedConfig {
            rounds: 30,
            local_generations: 10,
            aggregator: AggregationMethod::WeightedAverage,
            epsilon_per_round: 0.1,
            total_epsilon: 10.0,
        };
        let result = FederatedExperiment::run(nodes, config);
        // With shared target and enough rounds, should converge
        assert!(result.global_fitness() > 0.5, "Expected fitness > 0.5, got {}", result.global_fitness());
    }

    #[test]
    fn test_federation_result_best_node() {
        let nodes: Vec<Node> = (0..4).map(|i| Node::with_id(30, 8, i)).collect();
        let config = FederatedConfig::default();
        let result = FederatedExperiment::run(nodes, config);
        let best = result.best_node();
        assert!(best < result.num_nodes);
    }

    #[test]
    fn test_node_with_target() {
        let target = Strategy {
            values: vec![Ternary::Neg, Ternary::Pos, Ternary::Zero],
        };
        let node = Node::new(10, 3).with_target(target.clone());
        assert_eq!(node.target, target);
    }
}