moose 0.2.2

Encrypted learning and data processing framework
Documentation
//! gRPC-based choreography.

pub(crate) mod gen {
    tonic::include_proto!("moose_choreography");
}

use self::gen::choreography_server::{Choreography, ChoreographyServer};
use self::gen::{
    AbortComputationRequest, AbortComputationResponse, LaunchComputationRequest,
    LaunchComputationResponse, RetrieveResultsRequest, RetrieveResultsResponse,
};
use super::{NetworkingStrategy, StorageStrategy};
use crate::computation::{SessionId, Value};
use crate::execution::ExecutionContext;
use crate::execution::Identity;
use async_cell::sync::AsyncCell;
use async_trait::async_trait;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
pub struct ComputationOutputs {
    pub outputs: HashMap<String, Value>,
    pub elapsed_time: Option<Duration>,
}

type ResultStores = DashMap<SessionId, Arc<AsyncCell<ComputationOutputs>>>;

pub struct GrpcChoreography {
    own_identity: Identity,
    choreographer: Option<String>,
    result_stores: Arc<ResultStores>,
    networking_strategy: NetworkingStrategy,
    storage_strategy: StorageStrategy,
}

impl GrpcChoreography {
    pub fn new(
        own_identity: Identity,
        choreographer: Option<String>,
        networking_strategy: NetworkingStrategy,
        storage_strategy: StorageStrategy,
    ) -> GrpcChoreography {
        GrpcChoreography {
            own_identity,
            choreographer,
            result_stores: Arc::new(ResultStores::default()),
            networking_strategy,
            storage_strategy,
        }
    }

    pub fn into_server(self) -> ChoreographyServer<impl Choreography> {
        ChoreographyServer::new(self)
    }
}

impl GrpcChoreography {
    fn check_choreographer<T>(&self, request: &tonic::Request<T>) -> Result<(), tonic::Status> {
        let choreographer = crate::grpc::extract_sender(request).map_err(|_e| {
            tonic::Status::new(
                tonic::Code::Aborted,
                "failed to extract sender identity".to_string(),
            )
        })?;

        match (&self.choreographer, choreographer) {
            (None, None) => Ok(()),
            (None, Some(_actual)) => Err(tonic::Status::new(
                tonic::Code::Aborted,
                "did not expect choreographer certificate".to_string(),
            )),
            (Some(_expected), None) => Err(tonic::Status::new(
                tonic::Code::Aborted,
                "expected choreographer certificate".to_string(),
            )),
            (Some(expected), Some(actual)) => {
                if expected != &actual {
                    Err(tonic::Status::new(
                        tonic::Code::Aborted,
                        "expected choreographer did not match actual".to_string(),
                    ))
                } else {
                    Ok(())
                }
            }
        }
    }
}

#[async_trait]
impl Choreography for GrpcChoreography {
    async fn launch_computation(
        &self,
        request: tonic::Request<LaunchComputationRequest>,
    ) -> Result<tonic::Response<LaunchComputationResponse>, tonic::Status> {
        tracing::info!("Launching computation");

        self.check_choreographer(&request)?;
        let request = request.into_inner();

        let session_id = bincode::deserialize::<SessionId>(&request.session_id).map_err(|_e| {
            tonic::Status::new(
                tonic::Code::Aborted,
                "failed to parse session id".to_string(),
            )
        })?;

        match self.result_stores.entry(session_id.clone()) {
            Entry::Occupied(_) => Err(tonic::Status::new(
                tonic::Code::Aborted,
                "session id exists already or inconsistent metric and result map".to_string(),
            )),
            Entry::Vacant(result_stores_entry) => {
                let result_cell = AsyncCell::shared();
                result_stores_entry.insert(result_cell);

                let computation = bincode::deserialize(&request.computation).map_err(|_e| {
                    tonic::Status::new(
                        tonic::Code::Aborted,
                        "failed to parse computation".to_string(),
                    )
                })?;

                let arguments = bincode::deserialize(&request.arguments).map_err(|_e| {
                    tonic::Status::new(
                        tonic::Code::Aborted,
                        "failed to parse arguments".to_string(),
                    )
                })?;

                let role_assignments =
                    bincode::deserialize(&request.role_assignment).map_err(|_e| {
                        tonic::Status::new(
                            tonic::Code::Aborted,
                            "failed to parse role assignment".to_string(),
                        )
                    })?;

                let own_identity = self.own_identity.clone();
                let networking = (self.networking_strategy)(session_id.clone());
                let storage = (self.storage_strategy)();
                let context = ExecutionContext::new(own_identity, networking, storage);

                let execution_start_timer = Instant::now();

                let (_handle, outputs) = context
                    .execute_computation(
                        session_id.clone(),
                        &computation,
                        arguments,
                        role_assignments,
                    )
                    .await
                    .map_err(|_e| {
                        tonic::Status::new(
                            tonic::Code::Aborted,
                            "failed launch computation".to_string(),
                        )
                    })?;

                let result_stores = Arc::clone(&self.result_stores);

                tokio::spawn(async move {
                    let mut results = HashMap::with_capacity(outputs.len());
                    for (output_name, output_value) in outputs {
                        let value = output_value.await.unwrap();
                        results.insert(output_name, value);
                    }
                    tracing::info!("Results ready, {:?}", results.keys());

                    let result_cell = result_stores
                        .get(&session_id)
                        .expect("session disappeared unexpectedly");

                    let execution_stop_timer = Instant::now();
                    let elapsed_time = execution_stop_timer.duration_since(execution_start_timer);
                    result_cell.set(ComputationOutputs {
                        outputs: results,
                        elapsed_time: Some(elapsed_time),
                    });
                });

                Ok(tonic::Response::new(LaunchComputationResponse::default()))
            }
        }
    }

    async fn abort_computation(
        &self,
        _request: tonic::Request<AbortComputationRequest>,
    ) -> Result<tonic::Response<AbortComputationResponse>, tonic::Status> {
        unimplemented!()
    }

    async fn retrieve_results(
        &self,
        request: tonic::Request<RetrieveResultsRequest>,
    ) -> Result<tonic::Response<RetrieveResultsResponse>, tonic::Status> {
        self.check_choreographer(&request)?;
        let request = request.into_inner();

        let session_id = bincode::deserialize::<SessionId>(&request.session_id).map_err(|_e| {
            tonic::Status::new(
                tonic::Code::Aborted,
                "failed to parse session id".to_string(),
            )
        })?;

        match self.result_stores.get(&session_id) {
            Some(results) => {
                let results = results.value().get().await;
                let values = bincode::serialize(&results).expect("failed to serialize results");

                Ok(tonic::Response::new(RetrieveResultsResponse { values }))
            }
            None => Err(tonic::Status::new(
                tonic::Code::NotFound,
                "unknown session id".to_string(),
            )),
        }
    }
}