use serde_json::json;
use crate::api::runtime::NemoFlowContextState;
use crate::api::runtime::ToolExecutionNextFn;
use crate::api::runtime::current_scope_stack;
use crate::api::runtime::global_context;
use crate::api::scope::event;
use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
use crate::api::shared::{ensure_runtime_owner, resolve_parent_uuid, snapshot_event_subscribers};
use crate::error::{FlowError, Result};
use crate::json::Json;
use bitflags::bitflags;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use typed_builder::TypedBuilder;
use uuid::Uuid;
bitflags! {
/// Bit-flag set attached to a tool invocation.
///
/// The set is stored on [`ToolHandle::attributes`] and defaults to
/// `ToolAttributes::empty()` in every builder in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ToolAttributes: u32 {
/// Lowest bit of the flag set.
// NOTE(review): presumably marks a tool that executes remotely — confirm
// against the consumers of this flag; nothing in this file reads it.
const REMOTE = 0b01;
}
}
/// Record describing one tool invocation.
///
/// `tool_call` returns a handle and `tool_call_end` /
/// `emit_tool_end_without_output` take it back by reference to build the
/// matching end event.
///
/// Builder note: because of the `strip_option` field default, every `Option`
/// field gets a plain setter taking the inner value and a `*_opt` setter
/// taking the `Option` itself.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolHandle {
/// Identifier of the invocation; the builder defaults it to a fresh UUIDv7.
#[builder(default = Uuid::now_v7())]
pub uuid: Uuid,
/// Start timestamp; the builder defaults it to the current UTC time.
#[builder(default = Utc::now())]
pub started_at: DateTime<Utc>,
/// Tool name. `tool_call_end` passes it to the response sanitize chain.
#[builder(setter(into))]
pub name: String,
/// Optional JSON payload. `emit_tool_end_without_output` reuses it as the
/// data of the end event.
#[builder(default)]
pub data: Option<Json>,
/// Optional JSON metadata.
#[builder(default)]
pub metadata: Option<Json>,
/// Attribute flags; empty by default.
#[builder(default = ToolAttributes::empty())]
pub attributes: ToolAttributes,
/// UUID of the parent, if any.
#[builder(default)]
pub parent_uuid: Option<Uuid>,
/// Optional caller-supplied call identifier.
// NOTE(review): presumably the provider-side tool-call id — confirm with callers.
#[builder(default, setter(into))]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct CreateToolHandleParams<'a> {
pub name: &'a str,
#[builder(default)]
pub parent_uuid: Option<uuid::Uuid>,
#[builder(default = ToolAttributes::empty())]
pub attributes: ToolAttributes,
#[builder(default)]
pub data: Option<Json>,
#[builder(default)]
pub metadata: Option<Json>,
#[builder(default, setter(into))]
pub tool_call_id: Option<String>,
#[builder(default)]
pub timestamp: Option<DateTime<Utc>>,
}
/// Inputs for building the end event of a tool invocation.
///
/// `tool_call_end` builds this value and passes it to the runtime state's
/// `build_tool_end_event`. Every `Option` field has a plain setter taking the
/// inner value and a `*_opt` setter taking the `Option` itself.
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndToolHandleParams<'a> {
/// Handle of the invocation being ended.
pub handle: &'a ToolHandle,
/// Optional JSON payload for the end event.
#[builder(default)]
pub data: Option<Json>,
/// Optional JSON metadata for the end event.
#[builder(default)]
pub metadata: Option<Json>,
/// Optional explicit timestamp.
// NOTE(review): the fallback for `None` lives in `build_tool_end_event`,
// which is not visible in this file.
#[builder(default)]
pub timestamp: Option<DateTime<Utc>>,
}
/// Arguments accepted by [`tool_call`].
///
/// Every `Option` field has a plain setter taking the inner value and a
/// `*_opt` setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallParams<'a> {
/// Tool name; also passed to the request sanitize chain.
pub name: &'a str,
/// Tool arguments. `tool_call` runs them through the request sanitize
/// chain and attaches the sanitized value to the start event.
pub args: Json,
/// Optional parent scope; `tool_call` turns it into a parent UUID with
/// `resolve_parent_uuid`.
#[builder(default)]
pub parent: Option<&'a ScopeHandle>,
/// Attribute flags; empty by default.
#[builder(default = ToolAttributes::empty())]
pub attributes: ToolAttributes,
/// Optional JSON payload forwarded to the created handle.
#[builder(default)]
pub data: Option<Json>,
/// Optional JSON metadata forwarded to the created handle.
#[builder(default)]
pub metadata: Option<Json>,
/// Optional caller-supplied call identifier forwarded to the created handle.
#[builder(default, setter(into))]
pub tool_call_id: Option<String>,
/// Optional explicit timestamp forwarded to handle creation.
#[builder(default)]
pub timestamp: Option<DateTime<Utc>>,
}
/// Arguments accepted by [`tool_call_execute`].
///
/// Unlike [`ToolCallParams`], this struct owns its name and parent scope.
/// Every `Option` field has a plain setter taking the inner value and a
/// `*_opt` setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallExecuteParams {
/// Tool name.
#[builder(setter(into))]
pub name: String,
/// Tool arguments, before guardrails and request intercepts are applied.
pub args: Json,
/// The tool implementation. `tool_call_execute` hands it to
/// `tool_build_execution_chain` and awaits the resulting callable.
pub func: ToolExecutionNextFn,
/// Optional parent scope for the emitted events.
#[builder(default)]
pub parent: Option<ScopeHandle>,
/// Attribute flags; empty by default.
#[builder(default = ToolAttributes::empty())]
pub attributes: ToolAttributes,
/// Optional JSON payload forwarded to `tool_call` and `tool_call_end`.
#[builder(default)]
pub data: Option<Json>,
/// Optional JSON metadata forwarded to every event emitted for this call.
#[builder(default)]
pub metadata: Option<Json>,
}
/// Arguments accepted by [`tool_call_end`].
///
/// Every `Option` field has a plain setter taking the inner value and a
/// `*_opt` setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallEndParams<'a> {
/// Handle returned by the matching `tool_call`.
pub handle: &'a ToolHandle,
/// Tool result. `tool_call_end` runs it through the response sanitize
/// chain; a non-null sanitized value becomes the end event's data.
pub result: Json,
/// Fallback data for the end event, used only when the sanitized result
/// is JSON null.
#[builder(default)]
pub data: Option<Json>,
/// Optional JSON metadata for the end event.
#[builder(default)]
pub metadata: Option<Json>,
/// Optional explicit timestamp for the end event.
#[builder(default)]
pub timestamp: Option<DateTime<Utc>>,
}
/// Starts a tool invocation and emits its start event.
///
/// The arguments are passed through the request sanitize chain (global plus
/// scope-local guardrails), a [`ToolHandle`] is created from `params`, and a
/// start event carrying the sanitized arguments is emitted to the subscriber
/// snapshot taken under the locks.
///
/// # Errors
///
/// Propagates failures from `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
pub fn tool_call(params: ToolCallParams<'_>) -> Result<ToolHandle> {
    ensure_runtime_owner()?;
    let parent_uuid = resolve_parent_uuid(params.parent);
    let (handle, event, subscribers) = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.tool_sanitize_request_guardrails
        });
        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        let sanitized_args =
            state.tool_sanitize_request_chain(params.name, params.args, &scope_locals);
        let handle_params = CreateToolHandleParams::builder()
            .name(params.name)
            .parent_uuid_opt(parent_uuid)
            .attributes(params.attributes)
            .data_opt(params.data)
            .metadata_opt(params.metadata)
            .tool_call_id_opt(params.tool_call_id)
            .timestamp_opt(params.timestamp)
            .build();
        let handle = state.create_tool_handle(handle_params);
        let event = state.build_tool_start_event(&handle, Some(sanitized_args));
        (handle, event, subscribers)
    };
    // Both read guards were dropped at the end of the block above, so the
    // subscribers run without these locks held.
    NemoFlowContextState::emit_event(&event, &subscribers);
    Ok(handle)
}
/// Ends a tool invocation and emits its end event.
///
/// The result is passed through the response sanitize chain (global plus
/// scope-local guardrails). A non-null sanitized result becomes the end
/// event's data; when it is JSON null, `params.data` is used instead.
///
/// # Errors
///
/// Propagates failures from `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> {
    ensure_runtime_owner()?;
    let (event, subscribers) = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.tool_sanitize_response_guardrails
        });
        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        let sanitized_result =
            state.tool_sanitize_response_chain(&params.handle.name, params.result, &scope_locals);
        // The sanitized result takes precedence; caller-supplied data is only
        // a fallback for a null result.
        let data = if sanitized_result.is_null() {
            params.data
        } else {
            Some(sanitized_result)
        };
        let event = state.build_tool_end_event(
            EndToolHandleParams::builder()
                .handle(params.handle)
                .data_opt(data)
                .metadata_opt(params.metadata)
                .timestamp_opt(params.timestamp)
                .build(),
        );
        (event, subscribers)
    };
    // Both read guards were dropped at the end of the block above, so the
    // subscribers run without these locks held.
    NemoFlowContextState::emit_event(&event, &subscribers);
    Ok(())
}
/// Emits an end event for `handle` when there is no tool output to report.
///
/// The event is built by the runtime state's `end_tool_handle` from a clone
/// of the handle's own `data` plus the supplied `metadata`. No sanitize chain
/// is involved.
///
/// # Errors
///
/// Propagates failures from `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option<Json>) -> Result<()> {
    ensure_runtime_owner()?;

    let (end_event, listeners) = {
        let stack = current_scope_stack();
        let stack_guard = stack.read().expect("scope stack lock poisoned");
        let listeners = snapshot_event_subscribers(stack_guard.collect_scope_local_subscribers())?;

        let ctx = global_context();
        let ctx_state = ctx
            .read()
            .map_err(|poisoned| FlowError::Internal(poisoned.to_string()))?;

        let carried_data = handle.data.clone();
        let end_event = ctx_state.end_tool_handle(handle, carried_data, metadata);
        (end_event, listeners)
    };

    // Guards are gone by now; deliver the event to the snapshot.
    NemoFlowContextState::emit_event(&end_event, &listeners);
    Ok(())
}
/// Runs a tool end to end: guardrails, intercepts, execution and events.
///
/// Steps, in order:
/// 1. The conditional-execution guardrail chain is consulted. On rejection a
///    mark event carrying `{"rejected": true, "rejection_reason": ...}` is
///    emitted (its own failure is ignored) and
///    [`FlowError::GuardrailRejected`] is returned.
/// 2. The request intercept chain rewrites the arguments.
/// 3. [`tool_call`] creates the handle and emits the start event.
/// 4. `params.func` is wrapped by the execution intercept chain and awaited
///    with the intercepted arguments.
/// 5. On success [`tool_call_end`] emits the end event and the result is
///    returned. On failure an end event without output is emitted on a
///    best-effort basis and the execution error is returned.
///
/// # Errors
///
/// Returns [`FlowError::GuardrailRejected`] on guardrail rejection,
/// [`FlowError::Internal`] when the global context lock cannot be read, and
/// otherwise propagates errors from `ensure_runtime_owner`, the chains,
/// [`tool_call`], the tool itself and [`tool_call_end`].
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result<Json> {
    let ToolCallExecuteParams {
        name,
        args,
        func,
        parent,
        attributes,
        data,
        metadata,
    } = params;
    ensure_runtime_owner()?;

    // Guardrail check. The read guards live only inside this block, so the
    // mark-event emission below runs without them held.
    let rejection = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.tool_conditional_execution_guardrails
        });
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.tool_conditional_execution_chain(&name, &args, &scope_locals)?
    };
    if let Some(error) = rejection {
        let rejection_data = json!({
            "rejected": true,
            "rejection_reason": &error,
        });
        // Best effort: a failure to emit the mark event must not mask the
        // guardrail rejection itself.
        let _ = event(
            EmitMarkEventParams::builder()
                .name(&name)
                .parent_opt(parent.as_ref())
                .data(rejection_data)
                .metadata_opt(metadata.clone())
                .build(),
        );
        return Err(FlowError::GuardrailRejected(error));
    }

    let intercepted_args = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard
            .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.tool_request_intercepts_chain(&name, args, &scope_locals)?
    };

    let handle = tool_call(
        ToolCallParams::builder()
            .name(name.as_str())
            .args(intercepted_args.clone())
            .parent_opt(parent.as_ref())
            .attributes(attributes)
            .data_opt(data.clone())
            .metadata_opt(metadata.clone())
            .build(),
    )?;

    // Build the callable inside a block so no lock guard is alive across the
    // `.await` below.
    let execution = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard
            .collect_scope_local_registries(|registries| &registries.tool_execution_intercepts);
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.tool_build_execution_chain(&name, func, &scope_locals)
    };

    match execution(intercepted_args).await {
        Ok(result) => {
            tool_call_end(
                ToolCallEndParams::builder()
                    .handle(&handle)
                    .result(result.clone())
                    .data_opt(data)
                    .metadata_opt(metadata)
                    .build(),
            )?;
            Ok(result)
        }
        Err(error) => {
            // Best effort: still close the handle, but report the tool's own
            // error rather than any failure to emit the end event.
            let _ = emit_tool_end_without_output(&handle, metadata);
            Err(error)
        }
    }
}
/// Runs the tool request intercept chain on `args` and returns the result.
///
/// The chain is given the scope-local `tool_request_intercepts` registries
/// collected from the current scope stack. No event is emitted.
///
/// # Errors
///
/// Propagates failures from `ensure_runtime_owner` and from the chain, and
/// returns [`FlowError::Internal`] when the global context lock cannot be
/// read.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
pub fn tool_request_intercepts(name: &str, args: Json) -> Result<Json> {
    ensure_runtime_owner()?;
    let scope_stack = current_scope_stack();
    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
    let scope_locals = scope_guard
        .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
    let context = global_context();
    let state = context
        .read()
        .map_err(|error| FlowError::Internal(error.to_string()))?;
    state.tool_request_intercepts_chain(name, args, &scope_locals)
}
/// Checks whether the conditional-execution guardrails allow a tool call.
///
/// The chain is given the scope-local `tool_conditional_execution_guardrails`
/// registries collected from the current scope stack. Unlike
/// [`tool_call_execute`], a rejection here emits no event.
///
/// # Errors
///
/// Returns [`FlowError::GuardrailRejected`] when the chain reports a
/// rejection, [`FlowError::Internal`] when the global context lock cannot be
/// read, and otherwise propagates failures from `ensure_runtime_owner` and
/// the chain itself.
///
/// # Panics
///
/// Panics if the scope stack lock is poisoned.
pub fn tool_conditional_execution(name: &str, args: &Json) -> Result<()> {
    ensure_runtime_owner()?;
    let scope_stack = current_scope_stack();
    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
    let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
        &registries.tool_conditional_execution_guardrails
    });
    let context = global_context();
    let state = context
        .read()
        .map_err(|error| FlowError::Internal(error.to_string()))?;
    if let Some(error) = state.tool_conditional_execution_chain(name, args, &scope_locals)? {
        return Err(FlowError::GuardrailRejected(error));
    }
    Ok(())
}