use tonic::{Request, Response, Status};
use tokio::sync::Mutex;
use std::time::{Instant, Duration};
use tokio::time::sleep;
use futures::future::join_all; // To wait for concurrent tasks
use std::sync::Arc;
use crate::types::RoleId;
use crate::value::Value as RustValue;
use crate::value::sorted_map_entries;
use crate::pb::value::Kind;
use crate::pb::fizz_bee_mbt_plugin_service_server::FizzBeeMbtPluginService;
use crate::pb::{Value as ProtoValue, MapValue, MapEntry, ListValue};
use crate::pb::{
InitRequest, InitResponse,
CleanupRequest, CleanupResponse,
ExecuteActionRequest, ExecuteActionResponse,
ExecuteActionSequencesRequest, ExecuteActionSequencesResponse,
Interval, RoleRef, Status as ProtoStatus, StatusCode,
ActionSequence, ActionSequenceResult, ExecOptions,
};
use crate::traits::{Model, DispatchModel}; // <-- DEPENDS ONLY ON GENERIC L1 TRAITS
use crate::error::MbtError; // Your custom error type
// --- Helper functions: conversions between internal Rust types and gRPC/Protobuf types ---

/// Wraps an [`MbtError`] in a gRPC `INTERNAL` [`Status`].
///
/// Used by the RPC handlers when a failure should fail the RPC itself, rather than
/// being reported inside a successful response body.
fn mbt_error_to_status(err: MbtError) -> Status {
    let message = format!("MBT Execution Error: {}", err);
    Status::internal(message)
}
/// Converts a Protobuf [`RoleRef`] into the internal [`RoleId`].
///
/// `RoleRef::role_id` and `RoleId::index` are both `i32`, so the index is copied
/// unchanged (the previous `as i32` was a no-op cast).
///
/// # Errors
/// Currently never fails; the `Result` leaves room for future validation without
/// changing callers.
fn proto_ref_to_role_id(proto_ref: RoleRef) -> Result<RoleId, MbtError> {
    Ok(RoleId {
        role_name: proto_ref.role_name,
        index: proto_ref.role_id,
    })
}
/// Converts an internal [`RoleId`] into its Protobuf [`RoleRef`] form.
///
/// `RoleId::index` and `RoleRef::role_id` are both `i32`, so no cast is needed.
fn role_id_to_proto_ref(role_id: crate::types::RoleId) -> RoleRef {
    RoleRef {
        role_name: role_id.role_name,
        role_id: role_id.index,
    }
}
/// Recursively converts an internal [`RustValue`] into a Protobuf [`ProtoValue`].
///
/// - Scalars (`Int`, `Str`, `Bool`) map to the matching `Kind` variant.
/// - `Map` entries are emitted in the order produced by `sorted_map_entries`.
/// - `List` and `Set` both become a `ListValue`, because Protobuf has no set type.
/// - `None` becomes `ProtoValue::default()`, i.e. a value with no `kind` set.
fn rust_value_to_proto_value(rust_value: RustValue) -> ProtoValue {
    let kind = match rust_value {
        // Absent value: no `kind` at all, so bail out before wrapping.
        RustValue::None => return ProtoValue::default(),
        RustValue::Int(v) => Kind::IntValue(v),
        RustValue::Str(s) => Kind::StrValue(s),
        RustValue::Bool(b) => Kind::BoolValue(b),
        RustValue::Map(map) => {
            let to_entry = |(k, v)| MapEntry {
                key: Some(rust_value_to_proto_value(Clone::clone(k))),
                value: Some(rust_value_to_proto_value(Clone::clone(v))),
            };
            let entries = sorted_map_entries(&map).into_iter().map(to_entry).collect();
            Kind::MapValue(MapValue { entries })
        }
        RustValue::List(list) => {
            let items = list.into_iter().map(rust_value_to_proto_value).collect();
            Kind::ListValue(ListValue { items })
        }
        // NOTE(review): the element order follows the set's own iteration order. If the
        // set type is hash-based, this output is not deterministic — confirm.
        RustValue::Set(set) => {
            let items = set.into_iter().map(rust_value_to_proto_value).collect();
            Kind::ListValue(ListValue { items })
        }
    };
    ProtoValue { kind: Some(kind) }
}
/// Builds the in-body success [`ProtoStatus`]: code `StatusOk`, message `"OK"`.
fn proto_status_ok() -> ProtoStatus {
    let code = StatusCode::StatusOk as i32;
    ProtoStatus {
        message: String::from("OK"),
        code,
    }
}
/// Builds an in-body failure [`ProtoStatus`] from an [`MbtError`].
///
/// This is returned inside an otherwise successful gRPC response. Every error currently
/// maps to `StatusExecutionFailed`; a finer-grained mapping of error kinds to
/// `StatusCode`s would go here.
fn mbt_error_to_proto_status(err: MbtError) -> ProtoStatus {
    let message = format!("Execution Failed: {}", err);
    ProtoStatus {
        code: StatusCode::StatusExecutionFailed as i32,
        message,
    }
}
// --- Internal Command Structure ---
/// One action taken from an `ExecuteActionSequences` request, plus slots that are
/// filled in once the action has been executed.
///
/// All result fields are `None` when the command is built from the request.
#[derive(Debug, Clone)]
struct ExecuteActionCommand {
/// The original Protobuf request for this action; its role and action name are used for execution.
pub request: ExecuteActionRequest,
/// Copy of the enclosing sequence's options (the default when the sequence had none).
/// NOTE(review): this is never read in this file — confirm whether the options should be honored.
pub exec_options: ExecOptions,
// Results
/// Instant recorded just before the model executed this action.
pub start_time: Option<Instant>,
/// Instant recorded immediately after execution returned.
pub end_time: Option<Instant>,
/// Value returned by the model; `Some` only when execution succeeded.
pub return_value: Option<RustValue>,
/// Error returned by the model; `Some` only when execution failed.
pub error: Option<MbtError>,
}
/// Internal structure representing a sequence of actions.
/// Commands are kept in request order and executed in that order.
type ActionSequenceCommandBundle = Vec<ExecuteActionCommand>;
// --- The Generic Service Implementation Struct ---
/// gRPC implementation of `FizzBeeMbtPluginService` that forwards every call to a
/// user-supplied model/dispatcher `D`.
pub struct FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
/// The model, shared between RPC handlers and spawned sequence tasks.
/// Every model call in this file is made while holding this async mutex.
dispatcher: Arc<Mutex<D>>,
/// Reference point captured when the service is constructed.
/// Reported execution timestamps are nanoseconds elapsed since this instant.
base_instant: Instant,
}
impl<D> FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
pub fn new(dispatcher: D) -> Self {
FizzBeeServiceImpl {
dispatcher: Arc::new(Mutex::new(dispatcher)),
base_instant: Instant::now(),
}
}
/// Calls get_roles on the dispatcher and converts Rust RoleId structs to Protobuf RoleRef messages.
fn get_and_convert_roles(
dispatcher: &mut tokio::sync::MutexGuard<'_, D>
) -> Result<Vec<RoleRef>, MbtError> {
// 1. Call the Dispatcher's get_roles method
let rust_role_ids = dispatcher.get_roles()?;
// 2. Map Rust RoleId structs to Protobuf RoleRef messages
let proto_role_refs: Vec<RoleRef> = rust_role_ids.into_iter()
.map(role_id_to_proto_ref)
.collect();
Ok(proto_role_refs)
}
/// Calculates monotonic time in nanoseconds since the service was created.
fn nanos_since_base(&self, instant: Instant) -> i64 {
instant.duration_since(self.base_instant).as_nanos() as i64
}
// --- Helper for ExecuteActionSequences ---
/// Deserializes the proto request into internal command bundles.
fn deserialize_sequences(
req: Request<ExecuteActionSequencesRequest>
) -> Result<Vec<ActionSequenceCommandBundle>, MbtError> {
let proto_sequences = req.into_inner().action_sequence;
let mut all_bundles = Vec::with_capacity(proto_sequences.len());
for ActionSequence { requests, options } in proto_sequences {
let options = options.unwrap_or_default();
let mut bundle = Vec::with_capacity(requests.len());
for request in requests {
bundle.push(ExecuteActionCommand {
request,
exec_options: options.clone(),
start_time: None,
end_time: None,
return_value: None,
error: None,
});
}
all_bundles.push(bundle);
}
Ok(all_bundles)
}
/// Executes all command bundles concurrently using tokio::spawn and waits for results.
/// Each sequence is executed in a dedicated task, allowing their actions to interleave freely.
async fn execute_sequences_concurrent(
&self,
all_bundles: Vec<ActionSequenceCommandBundle>, // Take ownership (by value)
) -> Result<Vec<ActionSequenceCommandBundle>, MbtError> { // Return ownership on success
let mut futures = Vec::with_capacity(all_bundles.len());
// Iterate over the bundles, moving ownership and the original index into the future
for (seq_idx, mut sequence) in all_bundles.into_iter().enumerate() {
let dispatcher_arc = self.dispatcher.clone(); // Clone dispatcher Arc for the task
// The task will return a Result containing the original index and the executed bundle
let future = tokio::spawn(async move {
// Run each action sequentially within this task's owned sequence
for cmd in sequence.iter_mut() {
// Short, non-blocking sleep to encourage context switching and interleaving
sleep(Duration::from_micros(1)).await;
// Acquire the global lock on the model dispatcher D
let mut dispatcher_lock = dispatcher_arc.lock().await;
let start_time = Instant::now();
let action_name = cmd.request.action_name.clone();
// Convert Protobuf RoleRef to Rust RoleId struct.
let role_ref = cmd.request.role.clone().unwrap_or_default();
let role_id = proto_ref_to_role_id(role_ref)
.unwrap_or_else(|_| RoleId::default());
// Use `execute` method from the DispatchModel trait.
let (result, err) = match dispatcher_lock.execute(
&role_id, // &RoleId struct
&action_name,
// &[], // Placeholder for arguments
) {
Ok(val) => (Some(val), None),
Err(e) => (None, Some(e)),
};
drop(dispatcher_lock); // ***CRITICAL: Release lock immediately after execution***
// This allows other concurrent sequences to grab the lock and interleave.
let end_time = Instant::now();
// Record results locally in the owned sequence bundle
cmd.start_time = Some(start_time);
cmd.end_time = Some(end_time);
cmd.return_value = result;
cmd.error = err;
// Check for critical errors (excluding NotImplemented) and early exit.
if let Some(ref e) = cmd.error {
if !e.is_not_implemented() {
// Critical failure: return the MbtError
return Err(e.clone());
}
}
}
// On success, return the original index and the completed sequence bundle
Ok((seq_idx, sequence))
});
futures.push(future);
}
// --- Use join_all for cleaner concurrent waiting ---
let results = join_all(futures).await;
// FIX: Initialize the final result vector using iterator to avoid the Clone requirement.
let mut final_bundles: Vec<Option<ActionSequenceCommandBundle>> =
std::iter::repeat(None).take(results.len()).collect();
// Process results
for res in results {
match res {
// Task completed successfully (Outer Result Ok)
Ok(inner_res) => match inner_res {
// Execution succeeded (Inner Result Ok): place bundle at its original index
Ok((idx, sequence)) => {
final_bundles[idx] = Some(sequence);
},
// Execution failed with MbtError (Inner Result Err): critical error, return it immediately.
Err(e) => return Err(e),
},
// Task panicked (Outer Result Err): return a generalized MbtError
Err(e) => return Err(MbtError::other(format!("Action sequence task panicked: {}", e))),
}
}
// Unwrap the vector of Options into the final Vec<Bundle>, panic if one is missing (logic error)
let bundles = final_bundles.into_iter().map(|o| o.expect("Sequence missing after join_all. Logic error in indexing/processing.")).collect();
Ok(bundles) // Return the final vector
}
/// Serializes the executed command bundles back into a proto response.
fn serialize_sequence_results(
&self,
all_bundles: Vec<ActionSequenceCommandBundle>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
let mut results = Vec::with_capacity(all_bundles.len());
for bundle in all_bundles {
let mut action_responses = Vec::with_capacity(bundle.len());
for cmd in bundle {
// Convert ExecutionCommand to ExecuteActionResponse
// 1. Time Interval
let exec_time = if let (Some(start), Some(end)) = (cmd.start_time, cmd.end_time) {
Some(Interval {
start_unix_nano: self.nanos_since_base(start),
end_unix_nano: self.nanos_since_base(end),
})
} else {
None
};
// 2. Return Values and Status
let (return_values, status) = match cmd.return_value {
Some(RustValue::None) => (vec![], proto_status_ok()),
Some(value) => {
// Use the existing function `rust_value_to_proto_value`
let proto_value = rust_value_to_proto_value(value);
(vec![proto_value], proto_status_ok())
}
None => {
// Use mbt_error_to_proto_status
let error = cmd.error.unwrap_or_else(|| {
MbtError::other("Unknown error in sequence execution")
});
(vec![], mbt_error_to_proto_status(error))
}
};
// 3. Roles and State (Defaults to empty as they are not tracked during execution)
let roles = vec![];
let role_states = vec![];
action_responses.push(ExecuteActionResponse {
return_values,
exec_time,
status: Some(status),
roles,
role_states,
});
}
results.push(ActionSequenceResult {
responses: action_responses,
});
}
let response = ExecuteActionSequencesResponse { results };
Ok(Response::new(response))
}
}
#[tonic::async_trait]
impl<D> FizzBeeMbtPluginService for FizzBeeServiceImpl<D>
where
D: Model + DispatchModel + Send + Sync + 'static,
{
async fn init(
&self,
_request: Request<InitRequest>,
) -> Result<Response<InitResponse>, Status> {
println!("Init request received.");
// 1. Lock the Model
let mut dispatcher = self.dispatcher.lock().await;
match dispatcher.init() { // <-- Calls init() on the generic M type
Ok(_) => {
let proto_role_refs = Self::get_and_convert_roles(&mut dispatcher)
.map_err(mbt_error_to_status)?; // Convert MbtError to gRPC Status
let response = InitResponse {
status: Some(proto_status_ok()),
roles: proto_role_refs,
// NOTE: You would typically retrieve and set the initial roles list here
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
async fn cleanup(
&self,
_request: Request<CleanupRequest>,
) -> Result<Response<CleanupResponse>, Status> {
println!("Cleanup request received.");
// 1. Lock the Model
let mut model = self.dispatcher.lock().await;
match model.cleanup() { // <-- Calls cleanup() on the generic M type
Ok(_) => {
let response = CleanupResponse {
status: Some(proto_status_ok()),
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => Err(mbt_error_to_status(e)),
}
}
// --- execute_action will use self.dispatcher ---
async fn execute_action(
&self,
request: Request<ExecuteActionRequest>,
) -> Result<Response<ExecuteActionResponse>, Status> {
let req = request.into_inner();
// 1. Acquire the mutable lock on the dispatcher 🔒
let mut dispatcher = self.dispatcher.lock().await;
// 2. Convert Protobuf inputs to Rust types
// The field is named `role` in the request, not `role_id`
let role_ref = req.role
.ok_or_else(|| Status::invalid_argument("RoleRef is missing in request."))?;
// Convert RoleRef to your internal RoleId structure
let role_id = proto_ref_to_role_id(role_ref)
.map_err(mbt_error_to_status)?;
let action_name = req.action_name;
// NOTE: We ignore `req.args` for now, but they would be converted here.
// 3. Delegate execution to the Dispatcher's method (D: DispatchModel)
let result = dispatcher.execute(
&role_id,
&action_name,
// &[], /* args */
);
// 4. Convert the result back to the Protobuf response 🔄
match result {
Ok(returned_value) => {
// Success: Map Rust Value to Protobuf Value
println!("DEBUG: Returned Rust Value: {:?}", returned_value);
let proto_role_refs = Self::get_and_convert_roles(&mut dispatcher)
.map_err(mbt_error_to_status)?; // Convert MbtError to gRPC Status
// 2. Check for empty value (RustValue::None) and determine the return vector.
let return_values = match returned_value {
// Case 1: Value is None (empty). Return an empty vector for the repeated field.
crate::value::Value::None => {
println!("DEBUG: Returned value is RustValue::None. Sending empty vector in response.");
vec![]
},
// Case 2: Value is present. Convert it and place it in the vector.
value => {
let proto_value = rust_value_to_proto_value(value); // Consumes the non-None value
// 3. Debug Print: Protobuf Value after conversion
println!("DEBUG: Converted Protobuf Value: {:?}", proto_value);
vec![proto_value]
}
};
let response = ExecuteActionResponse {
return_values: return_values, // Use the dynamically determined vector
status: Some(proto_status_ok()),
roles: proto_role_refs,
..Default::default()
};
Ok(Response::new(response))
}
Err(e) => {
// Failure: Return a successful gRPC status (HTTP 200) but an internal
// failure status in the response body.
let response = ExecuteActionResponse {
status: Some(mbt_error_to_proto_status(e)),
// return_values will be empty
..Default::default()
};
Ok(Response::new(response))
}
}
}
async fn execute_action_sequences(
&self,
request: Request<ExecuteActionSequencesRequest>,
) -> Result<Response<ExecuteActionSequencesResponse>, Status> {
// Step 1: Deserialize upfront
let all_bundles = match Self::deserialize_sequences(request) {
Ok(bundles) => bundles,
Err(e) => return Err(mbt_error_to_status(e)),
};
// Step 2: Execute sequences concurrently, taking and returning ownership
let all_bundles = self.execute_sequences_concurrent(all_bundles)
.await
.map_err(mbt_error_to_status)?;
// Step 3: Serialize results
self.serialize_sequence_results(all_bundles)
}
}