use tonic::{Request, Response, Status};
use tokio::sync::RwLock;
use std::time::Instant;
use futures::future::join_all;
use std::sync::Arc;
use crate::types::RoleId;
use crate::value::Value as RustValue;
use crate::value::sorted_map_entries;
use crate::pb::value::Kind;
use crate::pb::fizz_bee_mbt_plugin_service_server::FizzBeeMbtPluginService;
use crate::pb::{Value as ProtoValue, MapValue, MapEntry, ListValue, SetValue, Arg as ProtoArg};
use crate::error::MbtError;
use crate::types::Arg as RustArg; use std::collections::{HashMap, HashSet};
use crate::pb::{
InitRequest, InitResponse,
CleanupRequest, CleanupResponse,
ExecuteActionRequest, ExecuteActionResponse,
ExecuteActionSequencesRequest, ExecuteActionSequencesResponse,
Interval, RoleRef, Status as ProtoStatus, StatusCode,
ActionSequence, ActionSequenceResult, ExecOptions,
};
use crate::traits::{Model, DispatchModel};
/// Maps an [`MbtError`] onto a gRPC `INTERNAL` status.
///
/// The status message is the error's `Display` text prefixed with
/// `"MBT Execution Error: "`.
fn mbt_error_to_status(err: MbtError) -> Status {
    let message = format!("MBT Execution Error: {err}");
    Status::internal(message)
}
fn proto_ref_to_role_id(proto_ref: RoleRef) -> Result<RoleId, MbtError> {
Ok(RoleId {
role_name: proto_ref.role_name,
index: proto_ref.role_id as i32
})
}
/// Converts a model-side `RoleId` into the protobuf `RoleRef`.
///
/// `role_name` is moved across unchanged and `index` becomes `role_id`
/// (via an `as i32` cast).
fn role_id_to_proto_ref(role_id: crate::types::RoleId) -> RoleRef {
    let crate::types::RoleId { role_name, index } = role_id;
    RoleRef {
        role_id: index as i32,
        role_name,
    }
}
/// Converts a model-side [`RustValue`] into its protobuf representation.
///
/// * `Int`, `Str` and `Bool` map directly onto the matching `Kind` variant.
/// * `Map` becomes a `MapValue` whose entries follow the order produced by
///   `sorted_map_entries`; keys and values are converted recursively.
/// * `List` becomes a `ListValue`, preserving element order.
/// * `Set` is also encoded as a `ListValue` (not a `SetValue`), with elements
///   in the set's own iteration order.
/// * `None` becomes a `ProtoValue` with no `kind` set.
fn rust_value_to_proto_value(rust_value: RustValue) -> ProtoValue {
    let kind = match rust_value {
        RustValue::None => return ProtoValue::default(),
        RustValue::Int(v) => Kind::IntValue(v),
        RustValue::Str(s) => Kind::StrValue(s),
        RustValue::Bool(b) => Kind::BoolValue(b),
        RustValue::Map(map) => {
            let mut entries = Vec::new();
            for (key, value) in sorted_map_entries(&map) {
                entries.push(MapEntry {
                    key: Some(rust_value_to_proto_value(key.clone())),
                    value: Some(rust_value_to_proto_value(value.clone())),
                });
            }
            Kind::MapValue(MapValue { entries })
        }
        RustValue::List(list) => Kind::ListValue(ListValue {
            items: list.into_iter().map(rust_value_to_proto_value).collect(),
        }),
        RustValue::Set(set) => {
            // Constructed and immediately discarded; this has no effect.
            // NOTE(review): presumably only keeps the `SetValue` import in use.
            let _ = SetValue::default();
            // NOTE(review): unlike maps, set elements are not sorted here, so
            // the list order follows the set's iteration order — confirm the
            // receiver compares sets order-insensitively.
            Kind::ListValue(ListValue {
                items: set.into_iter().map(rust_value_to_proto_value).collect(),
            })
        }
    };
    ProtoValue { kind: Some(kind) }
}
/// Converts a protobuf `ProtoValue` back into a model-side [`RustValue`].
///
/// A value with no `kind` set becomes `RustValue::None`. `MapValue` and
/// `ListValue` are converted recursively; `ListValue` always decodes to
/// `RustValue::List`. When a map contains duplicate keys the last entry wins.
///
/// # Errors
/// * `MbtError::Other` if a `MapEntry` is missing its key or its value.
/// * `MbtError::NotImplemented` for any other `Kind` variant.
/// * Any error raised while converting a nested element.
fn proto_value_to_rust_value(proto_value: ProtoValue) -> Result<RustValue, MbtError> {
    match proto_value.kind {
        None => Ok(RustValue::None),
        Some(Kind::StrValue(s)) => Ok(RustValue::Str(s)),
        Some(Kind::IntValue(v)) => Ok(RustValue::Int(v)),
        Some(Kind::BoolValue(b)) => Ok(RustValue::Bool(b)),
        Some(Kind::MapValue(MapValue { entries })) => {
            let converted: HashMap<_, _> = entries
                .into_iter()
                .map(|MapEntry { key, value }| -> Result<(RustValue, RustValue), MbtError> {
                    // Both halves are checked for presence before either is converted.
                    let key = key
                        .ok_or_else(|| MbtError::Other("MapEntry key is missing.".into()))?;
                    let value = value
                        .ok_or_else(|| MbtError::Other("MapEntry value is missing.".into()))?;
                    Ok((
                        proto_value_to_rust_value(key)?,
                        proto_value_to_rust_value(value)?,
                    ))
                })
                .collect::<Result<_, MbtError>>()?;
            Ok(RustValue::Map(converted))
        }
        Some(Kind::ListValue(ListValue { items })) => items
            .into_iter()
            .map(proto_value_to_rust_value)
            .collect::<Result<Vec<_>, MbtError>>()
            .map(RustValue::List),
        Some(_) => Err(MbtError::NotImplemented(
            "Unsupported Protobuf Value kind for conversion.".into(),
        )),
    }
}
pub fn proto_args_to_rust_args(proto_args: Vec<ProtoArg>) -> Result<Vec<RustArg>, MbtError> {
let mut rust_args = Vec::with_capacity(proto_args.len());
for proto_arg in proto_args {
let proto_value = proto_arg.value.ok_or_else(|| {
MbtError::Other(format!("Value is missing for argument '{}'.", proto_arg.name))
})?;
let rust_value = proto_value_to_rust_value(proto_value)?;
let rust_arg = RustArg {
name: proto_arg.name,
value: rust_value,
};
rust_args.push(rust_arg);
}
Ok(rust_args)
}
/// Builds the protobuf status attached to every successful response:
/// code `StatusOk` with the message `"OK"`.
fn proto_status_ok() -> ProtoStatus {
    let code = StatusCode::StatusOk as i32;
    ProtoStatus {
        message: String::from("OK"),
        code,
    }
}
/// Maps an [`MbtError`] onto an in-band protobuf status.
///
/// Errors for which `is_not_implemented()` is true become
/// `StatusNotImplemented` ("Not Implemented: ..."); every other error becomes
/// `StatusExecutionFailed` ("Execution Failed: ...").
fn mbt_error_to_proto_status(err: MbtError) -> ProtoStatus {
    let (code, message) = if err.is_not_implemented() {
        (
            StatusCode::StatusNotImplemented,
            format!("Not Implemented: {}", err),
        )
    } else {
        (
            StatusCode::StatusExecutionFailed,
            format!("Execution Failed: {}", err),
        )
    };
    ProtoStatus {
        code: code as i32,
        message,
    }
}
/// One action from an `ActionSequence`, carried through the deserialize →
/// execute → serialize pipeline of `execute_action_sequences`.
///
/// `deserialize_sequences` builds it with only `request`, `_exec_options` and
/// `args` populated. The timing, `return_value` and `error` fields start as
/// `None` and are filled in by `execute_sequences_concurrent`.
#[derive(Debug, Clone)]
struct ExecuteActionCommand {
    /// The original protobuf request. Its `args` field has been emptied with
    /// `std::mem::take`; the converted arguments live in `args` below.
    pub request: ExecuteActionRequest,
    /// Options of the enclosing `ActionSequence` (default if absent), cloned
    /// onto every command. Not read by any code in this section.
    pub _exec_options: ExecOptions,
    /// `request.args` converted to model-side arguments.
    pub args : Vec<RustArg>,
    /// Captured immediately before the model's `execute` call.
    pub start_time: Option<Instant>,
    /// Captured immediately after the model's `execute` call returned.
    pub end_time: Option<Instant>,
    /// Value returned by `execute` when it succeeded.
    pub return_value: Option<RustValue>,
    /// Error returned by `execute` when it failed.
    pub error: Option<MbtError>,
}
/// The commands of one `ActionSequence`, in request order.
type ActionSequenceCommandBundle = Vec<ExecuteActionCommand>;
/// gRPC implementation of `FizzBeeMbtPluginService` that forwards every call
/// to a user-supplied model `D`.
///
/// The model sits behind an async `RwLock` shared through an `Arc`, so that
/// the per-sequence tasks spawned while executing action sequences can each
/// hold their own handle to it.
pub struct FizzBeeServiceImpl<D: Model + DispatchModel + Send + Sync + 'static> {
    /// Reference point for the nanosecond timestamps reported in `Interval`s;
    /// captured when the service is constructed.
    base_instant: Instant,
    /// The model, guarded by an async reader-writer lock.
    dispatcher: Arc<RwLock<D>>,
}
impl<D> FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
pub fn new(dispatcher: D) -> Self {
FizzBeeServiceImpl {
dispatcher: Arc::new(RwLock::new(dispatcher)),
base_instant: Instant::now(),
}
}
fn get_and_convert_roles(
dispatcher: &tokio::sync::RwLockWriteGuard<'_, D>
) -> Result<Vec<RoleRef>, MbtError> {
let rust_role_ids = dispatcher.get_roles()?;
let proto_role_refs: Vec<RoleRef> = rust_role_ids.into_iter()
.map(role_id_to_proto_ref)
.collect();
Ok(proto_role_refs)
}
fn nanos_since_base(&self, instant: Instant) -> i64 {
instant.duration_since(self.base_instant).as_nanos() as i64
}
fn deserialize_sequences(
req: Request<ExecuteActionSequencesRequest>
) -> Result<Vec<ActionSequenceCommandBundle>, MbtError> {
let proto_sequences = req.into_inner().action_sequence;
let mut all_bundles = Vec::with_capacity(proto_sequences.len());
for ActionSequence { requests, options } in proto_sequences {
let options = options.unwrap_or_default();
let mut bundle = Vec::with_capacity(requests.len());
for mut request in requests {
let proto_args = std::mem::take(&mut request.args);
let rust_args = proto_args_to_rust_args(proto_args)?;
bundle.push(ExecuteActionCommand {
request,
_exec_options: options.clone(),
args: rust_args.clone(),
start_time: None,
end_time: None,
return_value: None,
error: None,
});
}
all_bundles.push(bundle);
}
Ok(all_bundles)
}
async fn execute_sequences_concurrent(
&self,
all_bundles: Vec<ActionSequenceCommandBundle>, ) -> Result<Vec<ActionSequenceCommandBundle>, MbtError> {
let mut futures = Vec::with_capacity(all_bundles.len());
for (seq_idx, mut sequence) in all_bundles.into_iter().enumerate() {
let dispatcher_arc = self.dispatcher.clone();
let future = tokio::spawn(async move {
for cmd in sequence.iter_mut() {
let action_name = cmd.request.action_name.clone();
let role_ref = cmd.request.role.clone().unwrap_or_default();
let role_id = proto_ref_to_role_id(role_ref)
.unwrap_or_else(|_| RoleId::default());
let dispatcher_read_lock = dispatcher_arc.read().await;
let start_time = Instant::now();
let (result, err) = match dispatcher_read_lock.execute(
&role_id, &action_name,
&cmd.args ).await {
Ok(val) => (Some(val), None),
Err(e) => (None, Some(e)),
};
let end_time = Instant::now();
drop(dispatcher_read_lock);
cmd.start_time = Some(start_time);
cmd.end_time = Some(end_time);
cmd.return_value = result;
cmd.error = err;
if let Some(ref e) = cmd.error {
if !e.is_not_implemented() {
return Err(e.clone());
}
}
}
Ok((seq_idx, sequence))
});
futures.push(future);
}
let results = join_all(futures).await;
let mut final_bundles: Vec<Option<ActionSequenceCommandBundle>> =
std::iter::repeat(None).take(results.len()).collect();
for res in results {
match res {
Ok(inner_res) => match inner_res {
Ok((idx, sequence)) => {
final_bundles[idx] = Some(sequence);
},
Err(e) => return Err(e),
},
Err(e) => return Err(MbtError::other(format!("Action sequence task failed: {}", e))),
}
}
let bundles = final_bundles.into_iter().map(|o| o.expect("Sequence missing after join_all. Logic error in indexing/processing.")).collect();
Ok(bundles) }
fn serialize_sequence_results(
&self,
all_bundles: Vec<ActionSequenceCommandBundle>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
let mut results = Vec::with_capacity(all_bundles.len());
for bundle in all_bundles {
let mut action_responses = Vec::with_capacity(bundle.len());
for cmd in bundle {
let exec_time = if let (Some(start), Some(end)) = (cmd.start_time, cmd.end_time) {
Some(Interval {
start_unix_nano: self.nanos_since_base(start),
end_unix_nano: self.nanos_since_base(end),
})
} else {
None
};
let (return_values, status) = match cmd.return_value {
Some(RustValue::None) => (vec![], proto_status_ok()),
Some(value) => {
let proto_value = rust_value_to_proto_value(value);
(vec![proto_value], proto_status_ok())
}
None => {
let error = cmd.error.unwrap_or_else(|| {
MbtError::other("Unknown error in sequence execution")
});
(vec![], mbt_error_to_proto_status(error))
}
};
let roles = vec![];
let role_states = vec![];
action_responses.push(ExecuteActionResponse {
return_values,
exec_time,
status: Some(status),
roles,
role_states,
});
}
results.push(ActionSequenceResult {
responses: action_responses,
});
}
let response = ExecuteActionSequencesResponse { results };
Ok(Response::new(response))
}
}
#[tonic::async_trait]
impl<D> FizzBeeMbtPluginService for FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
async fn init(
&self,
_request: Request<InitRequest>,
) -> Result<Response<InitResponse>, Status> {
let mut dispatcher = self.dispatcher.write().await;
match dispatcher.init().await {
Ok(_) => {
let proto_role_refs = Self::get_and_convert_roles(&dispatcher)
.map_err(mbt_error_to_status)?;
let response = InitResponse {
status: Some(proto_status_ok()),
roles: proto_role_refs,
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
async fn cleanup(
&self,
_request: Request<CleanupRequest>,
) -> Result<Response<CleanupResponse>, Status> {
let mut model = self.dispatcher.write().await;
match model.cleanup().await {
Ok(_) => {
let response = CleanupResponse {
status: Some(proto_status_ok()),
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
async fn execute_action(
&self,
request: Request<ExecuteActionRequest>,
) -> Result<Response<ExecuteActionResponse>, Status> {
let req = request.into_inner();
let dispatcher_read_lock = self.dispatcher.read().await;
let role_id = proto_ref_to_role_id(req.role.ok_or_else(|| Status::invalid_argument("RoleRef is missing in request."))?)
.map_err(mbt_error_to_status)?;
let rust_args = proto_args_to_rust_args(req.args)
.map_err(mbt_error_to_status)?;
let result = dispatcher_read_lock.execute(
&role_id,
&req.action_name,
&rust_args
).await;
drop(dispatcher_read_lock);
let dispatcher_write_lock = self.dispatcher.write().await;
match result {
Ok(returned_value) => {
let proto_role_refs = Self::get_and_convert_roles(&dispatcher_write_lock)
.map_err(mbt_error_to_status)?;
let return_values = match returned_value {
crate::value::Value::None => {
vec![]
},
value => {
let proto_value = rust_value_to_proto_value(value);
vec![proto_value]
}
};
let response = ExecuteActionResponse {
return_values: return_values,
status: Some(proto_status_ok()),
roles: proto_role_refs,
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => {
let response = ExecuteActionResponse {
status: Some(mbt_error_to_proto_status(e)),
..Default::default()
};
Ok(Response::new(response))
}
}
}
async fn execute_action_sequences(
&self,
request: Request<ExecuteActionSequencesRequest>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
let all_bundles = match Self::deserialize_sequences(request) {
Ok(bundles) => bundles,
Err(e) => return Err(mbt_error_to_status(e)),
};
let all_bundles = self.execute_sequences_concurrent(all_bundles)
.await
.map_err(mbt_error_to_status)?;
self.serialize_sequence_results(all_bundles)
}
}