use tonic::{Request, Response, Status};
use tokio::sync::Mutex;
use std::time::{Instant, Duration};
use tokio::time::sleep;
use futures::future::join_all; // To wait for concurrent tasks
use std::sync::Arc;
use crate::types::RoleId;
use crate::value::Value as RustValue;
use crate::value::sorted_map_entries;
use crate::pb::value::Kind;
use crate::pb::fizz_bee_mbt_plugin_service_server::FizzBeeMbtPluginService;
use crate::pb::{Value as ProtoValue, MapValue, MapEntry, ListValue};
use crate::pb::{
InitRequest, InitResponse,
CleanupRequest, CleanupResponse,
ExecuteActionRequest, ExecuteActionResponse,
ExecuteActionSequencesRequest, ExecuteActionSequencesResponse,
Interval, RoleRef, Status as ProtoStatus, StatusCode,
ActionSequence, ActionSequenceResult, ExecOptions,
};
use crate::traits::{Model, DispatchModel}; // <-- DEPENDS ONLY ON GENERIC L1 TRAITS
use crate::error::MbtError; // Your custom error type
/// Maps an `MbtError` to a gRPC `INTERNAL` status whose message embeds the error's `Display` text.
fn mbt_error_to_status(err: MbtError) -> Status {
    let message = format!("MBT Execution Error: {}", err);
    Status::internal(message)
}
// 2. Convert Protobuf RoleRef to your Rust RoleId type.
fn proto_ref_to_role_id(proto_ref: RoleRef) -> Result<RoleId, MbtError> {
Ok(RoleId {
role_name: proto_ref.role_name,
index: proto_ref.role_id as i32
})
}
/// Converts the crate's `RoleId` into a protobuf `RoleRef`.
fn role_id_to_proto_ref(role_id: crate::types::RoleId) -> RoleRef {
    let crate::types::RoleId { role_name, index } = role_id;
    // `as` is an unchecked cast: a no-op if `index` is already `i32`, and a
    // silent wrap for values outside `i32`'s range otherwise.
    RoleRef {
        role_name,
        role_id: index as i32,
    }
}
/// Converts an internal `RustValue` into a protobuf `ProtoValue`, recursing into containers.
///
/// - `Int`, `Str` and `Bool` map to the matching scalar kinds.
/// - `Map` becomes a `MapValue` whose entries follow the order produced by
///   `sorted_map_entries`.
/// - `List` and `Set` both become a `ListValue` (protobuf has no set type);
///   set elements are emitted in the set's own iteration order.
/// - `None` becomes a `ProtoValue` with no `kind` set (its `Default`).
fn rust_value_to_proto_value(rust_value: RustValue) -> ProtoValue {
    let to_list = |items: Vec<ProtoValue>| Kind::ListValue(ListValue { items });

    let kind = match rust_value {
        RustValue::Int(v) => Kind::IntValue(v),
        RustValue::Str(s) => Kind::StrValue(s),
        RustValue::Bool(b) => Kind::BoolValue(b),
        RustValue::Map(map) => {
            let mut entries = Vec::new();
            for (k, v) in sorted_map_entries(&map) {
                entries.push(MapEntry {
                    key: Some(rust_value_to_proto_value(k.clone())),
                    value: Some(rust_value_to_proto_value(v.clone())),
                });
            }
            Kind::MapValue(MapValue { entries })
        }
        RustValue::List(list) => to_list(list.into_iter().map(rust_value_to_proto_value).collect()),
        RustValue::Set(set) => to_list(set.into_iter().map(rust_value_to_proto_value).collect()),
        // An absent value is encoded as a message with no `kind` at all.
        RustValue::None => return ProtoValue::default(),
    };

    ProtoValue { kind: Some(kind) }
}
/// Builds the protobuf status used for a successful operation: code `StatusOk`, message `"OK"`.
fn proto_status_ok() -> ProtoStatus {
    let code = StatusCode::StatusOk as i32;
    ProtoStatus {
        code,
        message: String::from("OK"),
    }
}
/// Wraps an `MbtError` in a protobuf status for a response body.
///
/// Every error currently maps to `StatusExecutionFailed`, with the error's
/// `Display` text prefixed by `"Execution Failed: "`.
fn mbt_error_to_proto_status(err: MbtError) -> ProtoStatus {
    // TODO: map distinct MbtError kinds to more specific StatusCodes.
    let message = format!("Execution Failed: {}", err);
    ProtoStatus {
        code: StatusCode::StatusExecutionFailed as i32,
        message,
    }
}
// --- Internal Command Structure ---
/// One action from an `ExecuteActionSequencesRequest`, plus slots for its execution outcome.
///
/// The result fields start as `None` when the command is built from the request
/// and are filled in once the action has been executed.
#[derive(Debug)]
struct ExecuteActionCommand {
    /// The action exactly as received over the wire.
    pub request: ExecuteActionRequest,
    /// Options of the enclosing `ActionSequence` (cloned onto each of its commands).
    pub exec_options: ExecOptions,
    // Results
    /// When execution of the action began, if it ran.
    pub start_time: Option<Instant>,
    /// When execution of the action finished, if it ran.
    pub end_time: Option<Instant>,
    /// Value returned by the action on success.
    pub return_value: Option<RustValue>,
    /// Error returned by the action, if it failed.
    pub error: Option<MbtError>,
}
/// The commands of one `ActionSequence`, kept in the order they appeared in the request.
type ActionSequenceCommandBundle = Vec<ExecuteActionCommand>;
// --- The Generic Service Implementation Struct ---
// M: The Model type (for init/cleanup)
// D: The Dispatcher type (for execute)
pub struct FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
// The Dispatcher is stored to handle ExecuteAction
dispatcher: Arc<Mutex<D>>,
base_instant: Instant,
}
impl<D> FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
pub fn new(dispatcher: D) -> Self {
FizzBeeServiceImpl {
dispatcher: Arc::new(Mutex::new(dispatcher)),
base_instant: Instant::now(),
}
}
/// Calls get_roles on the dispatcher and converts Rust RoleId structs to Protobuf RoleRef messages.
fn get_and_convert_roles(
dispatcher: &mut tokio::sync::MutexGuard<'_, D>
) -> Result<Vec<RoleRef>, MbtError> {
// 1. Call the Dispatcher's get_roles method (Step 2 logic)
let rust_role_ids = dispatcher.get_roles()?;
// 2. Map Rust RoleId structs to Protobuf RoleRef messages (Step 3 logic)
let proto_role_refs: Vec<RoleRef> = rust_role_ids.into_iter()
.map(role_id_to_proto_ref)
.collect();
Ok(proto_role_refs)
}
/// Calculates monotonic time in nanoseconds since the service was created.
fn nanos_since_base(&self, instant: Instant) -> i64 {
instant.duration_since(self.base_instant).as_nanos() as i64
}
// --- Helper for ExecuteActionSequences ---
/// Deserializes the proto request into internal command bundles.
fn deserialize_sequences(
req: Request<ExecuteActionSequencesRequest>
) -> Result<Vec<ActionSequenceCommandBundle>, MbtError> {
let proto_sequences = req.into_inner().action_sequence;
let mut all_bundles = Vec::with_capacity(proto_sequences.len());
for ActionSequence { requests, options } in proto_sequences {
let options = options.unwrap_or_default();
let mut bundle = Vec::with_capacity(requests.len());
for request in requests {
bundle.push(ExecuteActionCommand {
request,
exec_options: options.clone(),
start_time: None,
end_time: None,
return_value: None,
error: None,
});
}
all_bundles.push(bundle);
}
Ok(all_bundles)
}
/// Executes all command bundles concurrently using tokio::spawn and waits for results.
async fn execute_sequences_concurrent(
&self,
all_bundles: &mut Vec<ActionSequenceCommandBundle>,
) -> Result<(), MbtError> {
// Wrap the mutable bundles in an Arc<Mutex> for safe concurrent access
let shared_bundles = Arc::new(Mutex::new(all_bundles));
let mut futures = Vec::with_capacity(all_bundles.len());
for seq_idx in 0..all_bundles.len() {
let dispatcher_arc = self.dispatcher.clone(); // Clone dispatcher Arc for the task
let shared_bundles_clone = shared_bundles.clone();
let future = tokio::spawn(async move {
// CRITICAL: Lock the parent vector to get a mutable reference to the sequence
let mut bundles_guard = shared_bundles_clone.lock()
.map_err(|_| MbtError::other("Mutex lock poisoned"))?;
let sequence: &mut ActionSequenceCommandBundle = bundles_guard.get_mut(seq_idx)
.ok_or_else(|| MbtError::other("Sequence index out of bounds"))?;
// Run each action sequentially within this sequence
for cmd in sequence.iter_mut() {
// Short, non-blocking sleep to increase context switching
sleep(Duration::from_micros(1)).await;
// Lock the dispatcher for the action execution
let dispatcher_lock = dispatcher_arc.lock().map_err(|_| MbtError::other("Mutex lock poisoned"))?;
let start_time = Instant::now();
let action_name = cmd.request.action_name.clone();
let role_id = cmd.request.role_id.clone().unwrap_or_default();
// Simulate Action Execution
let (result, err) = match dispatcher_lock.execute_action(
&role_id,
&action_name,
&[], // Placeholder for arguments
) {
Ok(val) => (Some(val), None),
Err(e) => (None, Some(e)),
};
drop(dispatcher_lock); // Release lock immediately
let end_time = Instant::now();
// Record results
cmd.start_time = Some(start_time);
cmd.end_time = Some(end_time);
cmd.return_value = result;
cmd.error = err;
// Check for critical errors (excluding NotImplemented) and early exit.
if let Some(ref e) = cmd.error {
// If the error is *not* NotImplemented, we stop the sequence and report failure.
if !e.is_not_implemented() {
return Err(e.clone());
}
}
}
Ok(())
});
futures.push(future);
}
// --- REFACTOR: Use join_all for cleaner concurrent waiting ---
let results = join_all(futures).await;
// Process results
for res in results {
match res {
Ok(Ok(_)) => continue, // Sequence completed successfully (Inner and Outer Result OK)
Ok(Err(e)) => return Err(e), // Sequence failed with MbtError
Err(e) => return Err(MbtError::other(format!("Action sequence task panicked: {}", e))), // Task panicked
}
}
// --- END REFACTOR ---
Ok(())
}
/// Serializes the executed command bundles back into a proto response.
fn serialize_sequence_results(
&self,
all_bundles: Vec<ActionSequenceCommandBundle>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
let mut results = Vec::with_capacity(all_bundles.len());
for bundle in all_bundles {
let mut action_responses = Vec::with_capacity(bundle.len());
for cmd in bundle {
// Convert ExecutionCommand to ExecuteActionResponse
// 1. Time Interval
let exec_time = if let (Some(start), Some(end)) = (cmd.start_time, cmd.end_time) {
Some(Interval {
start_unix_nano: self.nanos_since_base(start),
end_unix_nano: self.nanos_since_base(end),
})
} else {
None
};
// 2. Return Values and Status
let (return_values, status) = match cmd.return_value {
Some(RustValue::None) => (vec![], proto_status_ok()),
Some(value) => {
let proto_value = rust_value_to_proto_value_internal(value);
(vec![proto_value], proto_status_ok())
}
None => (vec![], mbt_error_to_status(cmd.error.unwrap_or_default()).into_inner()), // Use error status
};
// 3. Roles and State (Defaults to empty as they are not tracked during execution)
let roles = vec![];
let role_states = vec![];
action_responses.push(ExecuteActionResponse {
return_values,
exec_time,
status: Some(status),
roles,
role_states,
});
}
results.push(ActionSequenceResult {
responses: action_responses,
});
}
let response = ExecuteActionSequencesResponse { results };
Ok(Response::new(response))
}
}
#[tonic::async_trait]
impl<D> FizzBeeMbtPluginService for FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
async fn init(
&self,
_request: Request<InitRequest>,
) -> Result<Response<InitResponse>, Status> {
println!("Init request received.");
// 1. Lock the Model
let mut dispatcher = self.dispatcher.lock().await;
match dispatcher.init() { // <-- Calls init() on the generic M type
Ok(_) => {
let proto_role_refs = Self::get_and_convert_roles(&mut dispatcher)
.map_err(mbt_error_to_status)?; // Convert MbtError to gRPC Status
let response = InitResponse {
status: Some(proto_status_ok()),
roles: proto_role_refs,
// NOTE: You would typically retrieve and set the initial roles list here
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
async fn cleanup(
&self,
_request: Request<CleanupRequest>,
) -> Result<Response<CleanupResponse>, Status> {
println!("Cleanup request received.");
// 1. Lock the Model
let mut model = self.dispatcher.lock().await;
match model.cleanup() { // <-- Calls cleanup() on the generic M type
Ok(_) => {
let response = CleanupResponse {
status: Some(proto_status_ok()),
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
// --- execute_action will use self.dispatcher ---
async fn execute_action(
&self,
request: Request<ExecuteActionRequest>,
) -> Result<Response<ExecuteActionResponse>, Status> {
let req = request.into_inner();
// 1. Acquire the mutable lock on the dispatcher 🔒
let mut dispatcher = self.dispatcher.lock().await;
// 2. Convert Protobuf inputs to Rust types
let role_ref = req.role
.ok_or_else(|| Status::invalid_argument("RoleRef is missing in request."))?;
// Convert RoleRef to your internal RoleId structure
let role_id = proto_ref_to_role_id(role_ref)
.map_err(mbt_error_to_status)?;
let action_name = req.action_name;
// NOTE: We ignore `req.args` for now, but they would be converted here.
// 3. Delegate execution to the Dispatcher's method (D: DispatchModel)
let result = dispatcher.execute(
&role_id,
&action_name,
/* args */
);
// 4. Convert the result back to the Protobuf response 🔄
match result {
Ok(returned_value) => {
// Success: Map Rust Value to Protobuf Value
println!("DEBUG: Returned Rust Value: {:?}", returned_value);
let proto_role_refs = Self::get_and_convert_roles(&mut dispatcher)
.map_err(mbt_error_to_status)?; // Convert MbtError to gRPC Status
// 2. Check for empty value (RustValue::None) and determine the return vector.
let return_values = match returned_value {
// Case 1: Value is None (empty). Return an empty vector for the repeated field.
crate::value::Value::None => {
println!("DEBUG: Returned value is RustValue::None. Sending empty vector in response.");
vec![]
},
// Case 2: Value is present. Convert it and place it in the vector.
value => {
let proto_value = rust_value_to_proto_value(value); // Consumes the non-None value
// 3. Debug Print: Protobuf Value after conversion
println!("DEBUG: Converted Protobuf Value: {:?}", proto_value);
vec![proto_value]
}
};
let response = ExecuteActionResponse {
return_values: return_values, // Use the dynamically determined vector
status: Some(proto_status_ok()),
roles: proto_role_refs,
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => {
// Failure: Return a successful gRPC status (HTTP 200) but an internal
// failure status in the response body.
let response = ExecuteActionResponse {
status: Some(mbt_error_to_proto_status(e)),
// return_values will be empty
..Default::default()
};
Ok(Response::new(response))
}
}
}
async fn execute_action_sequences(
&self,
request: Request<ExecuteActionSequencesRequest>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
// Step 1: Deserialize upfront
let mut all_bundles = match self.deserialize_sequences(request) {
Ok(bundles) => bundles,
Err(e) => return Err(mbt_error_to_status(e)),
};
// Step 2: Execute sequences concurrently
if let Err(e) = self.execute_sequences_concurrent(&mut all_bundles).await {
// If the error is NOT_IMPLEMENTED, the serialization step handles it.
// If it's any other error, we exit immediately.
if !e.is_not_implemented() {
return Err(mbt_error_to_status(e));
}
}
// Step 3: Serialize results
self.serialize_sequence_results(all_bundles)
.map_err(|e| mbt_error_to_status(e.into_inner()))
}
}