//! smp-tee-runtime 0.1.0
//!
//! Hardened minimal runtime for TEE-based federated aggregation.
use std::collections::HashMap;

use crate::aggregation::{federated_averaging, multi_krum};

/// Aggregation rule that `execute_computation` applies to the decoded client vectors.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AggregationAlgorithm {
    /// Combine the inputs with `federated_averaging`.
    FederatedAveraging,
    /// Combine the inputs with `multi_krum`.
    MultiKrum {
        /// Forwarded unchanged as the second argument of `multi_krum`; presumably the
        /// number of Byzantine inputs to tolerate — confirm against `crate::aggregation`.
        byzantine_tolerance: usize,
    },
}

/// Settings for one `TeeGuard::execute_computation` call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ComputationParams {
    /// Aggregation rule to run over the input buffers.
    pub algorithm: AggregationAlgorithm,
}

/// Errors reported by `TeeGuard` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TeeError {
    /// An operation was attempted before `initialize` was called.
    NotInitialized,
    /// A buffer size is unusable, e.g. a zero-byte allocation or a write larger than the
    /// target buffer.
    InvalidAllocationSize,
    /// A pointer does not identify a buffer known to the runtime.
    InvalidPointer,
    /// Input data was rejected; the message explains why.
    InvalidInput(&'static str),
}

// A public error type should be printable and usable as `dyn std::error::Error`
// (e.g. `?` into `Box<dyn Error>`), so provide `Display` and `Error`.
impl std::fmt::Display for TeeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotInitialized => f.write_str("TEE runtime is not initialized"),
            Self::InvalidAllocationSize => f.write_str("invalid allocation size"),
            Self::InvalidPointer => f.write_str("pointer does not refer to a TEE allocation"),
            Self::InvalidInput(reason) => write!(f, "invalid input: {reason}"),
        }
    }
}

impl std::error::Error for TeeError {}

/// Interface to an execution environment that holds client payloads in buffers and runs
/// an aggregation over them.
///
/// A buffer is identified by the raw pointer returned from `allocate_memory`. Callers
/// hand that same pointer back to `write_data` and `execute_computation`.
pub trait TeeGuard {
    /// Prepares the environment for use. `InMemoryTee` rejects all other operations
    /// with `TeeError::NotInitialized` until this has been called; other implementations
    /// presumably behave alike — confirm per implementation.
    fn initialize(&mut self) -> Result<(), TeeError>;

    /// Reserves a buffer of `size` bytes and returns the pointer that identifies it.
    fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError>;

    /// Copies `data` into the buffer identified by `ptr`.
    fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError>;

    /// Runs the aggregation chosen in `params` over the buffers named by `input_ptrs`
    /// and returns the encoded result.
    fn execute_computation(
        &self,
        input_ptrs: &[*const u8],
        params: &ComputationParams,
    ) -> Result<Vec<u8>, TeeError>;
}

/// `TeeGuard` implementation backed by ordinary heap buffers.
///
/// It provides no hardware isolation: buffers are plain `Vec<u8>`s owned by this struct.
#[derive(Default, Debug)]
pub struct InMemoryTee {
    /// Set by `initialize`. Every other operation is refused while this is `false`.
    initialized: bool,
    /// Live buffers keyed by the address of their first byte. That address is also the
    /// handle returned by `allocate_memory`.
    allocations: HashMap<usize, Vec<u8>>,
}

impl InMemoryTee {
    /// Decodes the whole buffer whose first byte is at `ptr` as little-endian `f32`s.
    ///
    /// Bytes never overwritten by `write_data` are decoded too; they are zero from the
    /// initial allocation.
    ///
    /// # Errors
    /// - `TeeError::InvalidPointer` if `ptr` is not the start of a tracked buffer.
    /// - `TeeError::InvalidInput` if the buffer length is not a multiple of 4.
    fn read_vector(&self, ptr: *const u8) -> Result<Vec<f32>, TeeError> {
        let Some(raw) = self.allocations.get(&(ptr as usize)) else {
            return Err(TeeError::InvalidPointer);
        };

        // Leftover bytes after splitting into 4-byte words mean the payload is not a
        // whole number of f32 values.
        let words = raw.chunks_exact(4);
        if !words.remainder().is_empty() {
            return Err(TeeError::InvalidInput(
                "payload length must be a multiple of 4",
            ));
        }

        Ok(words
            .map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect())
    }

    /// Serializes `values` as consecutive little-endian `f32` bytes (4 bytes per value).
    fn encode_vector(values: &[f32]) -> Vec<u8> {
        let mut encoded = Vec::with_capacity(values.len() * 4);
        for value in values {
            encoded.extend_from_slice(&value.to_le_bytes());
        }
        encoded
    }
}

impl TeeGuard for InMemoryTee {
    /// Marks the runtime as ready. Calling it more than once is harmless.
    fn initialize(&mut self) -> Result<(), TeeError> {
        self.initialized = true;
        Ok(())
    }

    /// Allocates a zero-filled buffer of `size` bytes. Returns the address of its first
    /// byte, which also serves as the buffer's handle.
    ///
    /// # Errors
    /// - `TeeError::NotInitialized` before `initialize` has been called.
    /// - `TeeError::InvalidAllocationSize` if `size` is zero or the buffer cannot be
    ///   reserved (for example `size > isize::MAX`, or the allocator refuses).
    fn allocate_memory(&mut self, size: usize) -> Result<*mut u8, TeeError> {
        if !self.initialized {
            return Err(TeeError::NotInitialized);
        }
        if size == 0 {
            return Err(TeeError::InvalidAllocationSize);
        }

        // `vec![0; size]` panics on capacity overflow and aborts the process on allocator
        // failure. A caller-controlled size must not be able to do either, so reserve
        // fallibly first; the subsequent resize does not reallocate.
        let mut allocation: Vec<u8> = Vec::new();
        allocation
            .try_reserve_exact(size)
            .map_err(|_| TeeError::InvalidAllocationSize)?;
        allocation.resize(size, 0_u8);

        let ptr = allocation.as_mut_ptr();
        // Moving the Vec into the map moves only its header, not the heap buffer. So
        // `ptr` stays valid and, while the buffer is live, unique as a key.
        self.allocations.insert(ptr as usize, allocation);
        Ok(ptr)
    }

    /// Copies `data` into the start of the buffer identified by `ptr`. Bytes beyond
    /// `data.len()` keep their previous contents.
    ///
    /// # Errors
    /// - `TeeError::NotInitialized` before `initialize` has been called.
    /// - `TeeError::InvalidPointer` if `ptr` is not a handle from `allocate_memory`.
    /// - `TeeError::InvalidAllocationSize` if `data` is longer than the buffer.
    fn write_data(&mut self, ptr: *mut u8, data: &[u8]) -> Result<(), TeeError> {
        if !self.initialized {
            return Err(TeeError::NotInitialized);
        }

        let buffer = self
            .allocations
            .get_mut(&(ptr as usize))
            .ok_or(TeeError::InvalidPointer)?;

        if data.len() > buffer.len() {
            return Err(TeeError::InvalidAllocationSize);
        }

        buffer[..data.len()].copy_from_slice(data);
        Ok(())
    }

    /// Decodes every input buffer as little-endian `f32`s and aggregates them with the
    /// algorithm in `params`. The result is re-encoded the same way.
    ///
    /// # Errors
    /// - `TeeError::NotInitialized` before `initialize` has been called.
    /// - `TeeError::InvalidPointer` / `TeeError::InvalidInput` if an input buffer is
    ///   unknown or not a whole number of `f32`s.
    /// - `TeeError::InvalidInput` if the selected aggregation function returns `None`.
    fn execute_computation(
        &self,
        input_ptrs: &[*const u8],
        params: &ComputationParams,
    ) -> Result<Vec<u8>, TeeError> {
        if !self.initialized {
            return Err(TeeError::NotInitialized);
        }

        // Fail fast on the first unreadable input rather than aggregating a partial set.
        let vectors = input_ptrs
            .iter()
            .map(|ptr| self.read_vector(*ptr))
            .collect::<Result<Vec<_>, _>>()?;

        let result = match params.algorithm {
            AggregationAlgorithm::FederatedAveraging => federated_averaging(&vectors)
                .ok_or(TeeError::InvalidInput("invalid federated averaging input"))?,
            AggregationAlgorithm::MultiKrum {
                byzantine_tolerance,
            } => multi_krum(&vectors, byzantine_tolerance)
                .ok_or(TeeError::InvalidInput("invalid multi-krum input"))?,
        };

        Ok(Self::encode_vector(&result))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `values` as little-endian `f32` bytes, the layout `InMemoryTee` decodes.
    fn to_bytes(values: &[f32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(values.len() * 4);
        for value in values {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        bytes
    }

    /// Inverse of `to_bytes`; any trailing partial word is ignored.
    fn to_f32(bytes: &[u8]) -> Vec<f32> {
        let mut values = Vec::with_capacity(bytes.len() / 4);
        for word in bytes.chunks_exact(4) {
            values.push(f32::from_le_bytes([word[0], word[1], word[2], word[3]]));
        }
        values
    }

    #[test]
    fn tee_executes_federated_averaging() {
        let mut tee = InMemoryTee::default();
        tee.initialize().unwrap();

        // Each client update is two f32s, so it fills exactly one 8-byte buffer.
        let updates = [[1.0_f32, 3.0], [3.0, 5.0]];
        let mut handles: Vec<*const u8> = Vec::with_capacity(updates.len());
        for update in &updates {
            let buffer = tee.allocate_memory(8).unwrap();
            tee.write_data(buffer, &to_bytes(update)).unwrap();
            handles.push(buffer.cast_const());
        }

        let params = ComputationParams {
            algorithm: AggregationAlgorithm::FederatedAveraging,
        };
        let averaged = tee.execute_computation(&handles, &params).unwrap();

        assert_eq!(to_f32(&averaged), vec![2.0, 4.0]);
    }
}