use std::collections::HashMap;
use async_trait::async_trait;
use oatf::enums::{ElicitationMode, LogLevel};
use oatf::primitives::{interpolate_template, interpolate_value};
use crate::error::EngineError;
/// Outbound channel through which entry actions deliver protocol messages.
///
/// `execute_entry_actions` calls [`EntryActionSender::send_notification`]
/// for every `oatf::Action::Send` it encounters. Implementors must be
/// `Send + Sync`, as required by the trait bound.
#[async_trait]
pub trait EntryActionSender: Send + Sync {
/// Sends a notification named `method`, optionally carrying JSON `params`.
///
/// # Errors
///
/// Returns an [`EngineError`] when the implementor cannot deliver the
/// notification; the exact failure conditions are implementation-defined.
async fn send_notification(
&self,
method: &str,
params: Option<&serde_json::Value>,
) -> Result<(), EngineError>;
/// Sends an elicitation carrying `message` plus the optional `mode`,
/// `requested_schema`, `url` and `elicitation_id` values.
///
/// NOTE(review): nothing in the visible part of this file invokes this
/// method, so the meaning of each optional argument cannot be confirmed
/// from here — presumably they map onto the fields of an elicitation
/// request; verify against the implementors.
///
/// # Errors
///
/// Returns an [`EngineError`] when the implementor cannot deliver the
/// elicitation; the exact failure conditions are implementation-defined.
async fn send_elicitation(
&self,
message: &str,
mode: Option<&ElicitationMode>,
requested_schema: Option<&serde_json::Value>,
url: Option<&str>,
elicitation_id: Option<&str>,
) -> Result<(), EngineError>;
}
#[allow(clippy::implicit_hasher, clippy::cognitive_complexity)]
pub async fn execute_entry_actions(
actions: &[oatf::Action],
extractors: &HashMap<String, String>,
sender: Option<&dyn EntryActionSender>,
) {
for action in actions {
match action {
oatf::Action::Send {
method,
params,
extensions: _,
} => {
let params = params
.as_ref()
.map(|p| interpolate_value(p, extractors, None, None).0);
if let Some(sender) = sender {
if let Err(err) = sender.send_notification(method, params.as_ref()).await {
tracing::warn!(
method,
error = %err,
"failed to send entry action"
);
}
} else {
tracing::warn!(method, "no entry action sender available — skipping send");
}
}
oatf::Action::Log {
message,
level,
extensions: _,
} => {
let (interpolated_message, _) =
interpolate_template(message, extractors, None, None);
match level {
Some(LogLevel::Error) => {
tracing::error!(source = "entry_action", "{}", interpolated_message);
}
Some(LogLevel::Warn) => {
tracing::warn!(source = "entry_action", "{}", interpolated_message);
}
Some(LogLevel::Info) | None => {
tracing::info!(source = "entry_action", "{}", interpolated_message);
}
}
}
oatf::Action::BindingSpecific {
key,
value: _,
extensions: _,
} => {
tracing::debug!(
action = key,
"binding-specific entry action — not handled by core engine"
);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use indexmap::IndexMap;
    use std::sync::Mutex;

    /// Test double for [`EntryActionSender`] that records every notification
    /// it receives (method and params) and can be told to fail each one.
    struct MockSender {
        /// When `true`, `send_notification` records the call and then errors.
        fail_notifications: bool,
        /// One `"notification:<method>"` entry per `send_notification` call.
        calls: Mutex<Vec<String>>,
        /// The params passed to each `send_notification` call, in call order.
        params: Mutex<Vec<Option<serde_json::Value>>>,
    }

    impl MockSender {
        /// Shared constructor behind the two named variants below.
        fn new(fail_notifications: bool) -> Self {
            Self {
                fail_notifications,
                calls: Mutex::new(Vec::new()),
                params: Mutex::new(Vec::new()),
            }
        }

        /// A sender whose notifications always succeed.
        fn succeeding() -> Self {
            Self::new(false)
        }

        /// A sender whose notifications always fail after being recorded.
        fn failing_notifications() -> Self {
            Self::new(true)
        }

        /// Snapshot of the recorded call labels.
        fn call_log(&self) -> Vec<String> {
            self.calls.lock().unwrap().clone()
        }

        /// Snapshot of the recorded notification params.
        fn params_log(&self) -> Vec<Option<serde_json::Value>> {
            self.params.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl EntryActionSender for MockSender {
        async fn send_notification(
            &self,
            method: &str,
            params: Option<&serde_json::Value>,
        ) -> Result<(), EngineError> {
            // Record first so failing sends are still visible to assertions.
            self.calls
                .lock()
                .unwrap()
                .push(format!("notification:{method}"));
            self.params.lock().unwrap().push(params.cloned());
            if self.fail_notifications {
                Err(EngineError::EntryAction("transport closed".to_string()))
            } else {
                Ok(())
            }
        }

        async fn send_elicitation(
            &self,
            _message: &str,
            _mode: Option<&ElicitationMode>,
            _requested_schema: Option<&serde_json::Value>,
            _url: Option<&str>,
            _elicitation_id: Option<&str>,
        ) -> Result<(), EngineError> {
            Ok(())
        }
    }

    /// Builds a `Send` action with no extensions.
    fn send_action(method: &str, params: Option<serde_json::Value>) -> oatf::Action {
        oatf::Action::Send {
            method: method.to_string(),
            params,
            extensions: IndexMap::new(),
        }
    }

    /// Builds a `Log` action with no extensions.
    fn log_action(message: &str, level: Option<LogLevel>) -> oatf::Action {
        oatf::Action::Log {
            message: message.to_string(),
            level,
            extensions: IndexMap::new(),
        }
    }

    /// A templated log message runs to completion without a sender.
    #[tokio::test]
    async fn log_action_executes_without_panic() {
        let actions = vec![log_action(
            "phase entered: {{phase_name}}",
            Some(LogLevel::Info),
        )];
        let mut extractors = HashMap::new();
        extractors.insert("phase_name".to_string(), "exploit".to_string());
        execute_entry_actions(&actions, &extractors, None).await;
    }

    /// A `Send` action with no sender is skipped rather than panicking.
    #[tokio::test]
    async fn send_without_sender_does_not_panic() {
        let actions = vec![send_action("notifications/tools/list_changed", None)];
        let extractors = HashMap::new();
        execute_entry_actions(&actions, &extractors, None).await;
    }

    /// Binding-specific actions are accepted by the core executor.
    #[tokio::test]
    async fn binding_specific_action_logged() {
        let actions = vec![oatf::Action::BindingSpecific {
            key: "send_request".to_string(),
            value: serde_json::json!({"method": "custom"}),
            extensions: IndexMap::new(),
        }];
        let extractors = HashMap::new();
        execute_entry_actions(&actions, &extractors, None).await;
    }

    /// Every supported log level (including the `None` default) is handled.
    #[tokio::test]
    async fn multiple_actions_execute_in_order() {
        let actions = vec![
            log_action("first", None),
            log_action("second", Some(LogLevel::Warn)),
            log_action("third", Some(LogLevel::Error)),
        ];
        let extractors = HashMap::new();
        execute_entry_actions(&actions, &extractors, None).await;
    }

    /// A failing send is attempted exactly once and does not abort the run.
    #[tokio::test]
    async fn send_sender_error_continues_to_next_action() {
        let sender = MockSender::failing_notifications();
        let actions = vec![
            send_action("notifications/tools/list_changed", None),
            log_action("after send", None),
        ];
        let extractors = HashMap::new();
        execute_entry_actions(&actions, &extractors, Some(&sender)).await;
        let calls = sender.call_log();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0], "notification:notifications/tools/list_changed");
    }

    /// Params are interpolated against the extractors before being sent.
    #[tokio::test]
    async fn send_with_interpolated_params() {
        let sender = MockSender::succeeding();
        let actions = vec![send_action(
            "notifications/resources/updated",
            Some(serde_json::json!({"uri": "file:///{{path}}"})),
        )];
        let mut extractors = HashMap::new();
        extractors.insert("path".to_string(), "secret.txt".to_string());
        execute_entry_actions(&actions, &extractors, Some(&sender)).await;
        let calls = sender.call_log();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0], "notification:notifications/resources/updated");
        // The sender must see the substituted value, not the raw template.
        assert_eq!(
            sender.params_log(),
            vec![Some(serde_json::json!({"uri": "file:///secret.txt"}))]
        );
    }

    /// Each send is attempted, in order, even when all of them fail.
    #[tokio::test]
    async fn all_sender_errors_logged_but_execution_completes() {
        let sender = MockSender::failing_notifications();
        let actions = vec![
            send_action("notify1", None),
            send_action("notify2", None),
            send_action("notify3", None),
        ];
        let extractors = HashMap::new();
        execute_entry_actions(&actions, &extractors, Some(&sender)).await;
        let calls = sender.call_log();
        assert_eq!(calls.len(), 3);
        assert_eq!(calls[0], "notification:notify1");
        assert_eq!(calls[1], "notification:notify2");
        assert_eq!(calls[2], "notification:notify3");
        // Absent params are forwarded as absent.
        assert!(sender.params_log().iter().all(Option::is_none));
    }
}