Skip to main content

nemo_flow/api/
tool.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde_json::json;
5
6use crate::api::runtime::NemoFlowContextState;
7use crate::api::runtime::ToolExecutionNextFn;
8use crate::api::runtime::current_scope_stack;
9use crate::api::runtime::global_context;
10use crate::api::scope::event;
11use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
12use crate::api::shared::{ensure_runtime_owner, resolve_parent_uuid, snapshot_event_subscribers};
13use crate::error::{FlowError, Result};
14use crate::json::Json;
15use bitflags::bitflags;
16use chrono::{DateTime, Utc};
17use serde::{Deserialize, Serialize};
18use typed_builder::TypedBuilder;
19use uuid::Uuid;
20
21bitflags! {
22    /// Bitflags that modify tool-call behavior and observability.
23    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24    pub struct ToolAttributes: u32 {
25        /// Marks the tool as executing out-of-process.
26        const REMOTE = 0b01;
27    }
28}
29
/// Runtime-owned handle identifying an active or completed tool call.
///
/// Returned by [`tool_call`] and passed back to [`tool_call_end`] to close the
/// span. Built with a typed builder: every `Option` field gets a setter taking
/// the inner value plus an `_opt` setter taking the `Option` itself.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolHandle {
    /// Unique tool-call identifier. The builder defaults it to a fresh UUIDv7.
    #[builder(default = Uuid::now_v7())]
    pub uuid: Uuid,
    /// Timestamp captured when the tool handle was created. The builder
    /// defaults it to `Utc::now()`.
    #[builder(default = Utc::now())]
    pub started_at: DateTime<Utc>,
    /// Tool name recorded on lifecycle events. The setter accepts any
    /// `Into<String>` value.
    #[builder(setter(into))]
    pub name: String,
    /// Optional application payload stored on the handle. Failure paths in
    /// [`tool_call_execute`] hand a clone of it to the end event.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata attached to the tool span.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Tool behavior flags; empty unless set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// UUID of the parent scope, if any.
    #[builder(default)]
    pub parent_uuid: Option<Uuid>,
    /// Optional provider-specific tool-call correlation identifier. The setter
    /// accepts any `Into<String>` value.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
}
59
60/// Builder parameters for [`NemoFlowContextState::create_tool_handle`].
61#[derive(Debug, Clone, TypedBuilder)]
62#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
63pub struct CreateToolHandleParams<'a> {
64    /// Tool name recorded on emitted events.
65    pub name: &'a str,
66    /// Optional parent scope UUID.
67    #[builder(default)]
68    pub parent_uuid: Option<uuid::Uuid>,
69    /// Tool attribute bitflags.
70    #[builder(default = ToolAttributes::empty())]
71    pub attributes: ToolAttributes,
72    /// Optional application payload stored on the handle.
73    #[builder(default)]
74    pub data: Option<Json>,
75    /// Optional metadata stored on the handle.
76    #[builder(default)]
77    pub metadata: Option<Json>,
78    /// Optional provider-specific correlation identifier.
79    #[builder(default, setter(into))]
80    pub tool_call_id: Option<String>,
81    /// Optional timestamp captured as the handle start time and reused by the
82    /// emitted start event. When omitted, the current UTC time is used.
83    #[builder(default)]
84    pub timestamp: Option<DateTime<Utc>>,
85}
86
/// Builder parameters for [`NemoFlowContextState::build_tool_end_event`].
///
/// Every `Option` field exposes a setter taking the inner value and an `_opt`
/// setter taking the `Option` itself.
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndToolHandleParams<'a> {
    /// Tool handle to serialize into the emitted end event.
    pub handle: &'a ToolHandle,
    /// Optional data payload merged over the handle data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata payload merged over the handle metadata.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional timestamp recorded on the emitted end event. When omitted, the
    /// runtime records the current UTC time, or one microsecond after the
    /// handle start time if the current time is not later.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
105
/// Builder parameters for [`tool_call`].
///
/// Every `Option` field exposes a setter taking the inner value and an `_opt`
/// setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallParams<'a> {
    /// Tool name recorded on the emitted lifecycle event.
    pub name: &'a str,
    /// Raw tool arguments associated with the span. Only the sanitized form
    /// (after sanitize-request guardrails) is emitted as start-event data.
    pub args: Json,
    /// Optional explicit parent scope.
    #[builder(default)]
    pub parent: Option<&'a ScopeHandle>,
    /// Tool attribute bitflags applied to the span; empty unless set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional application payload stored on the handle but not emitted as
    /// Agent Trajectory Observability Format (ATOF) data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on the start event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional provider-specific correlation identifier; the setter accepts
    /// any `Into<String>` value.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
    /// Optional timestamp captured as the handle start time and reused by the
    /// emitted start event. When omitted, the current UTC time is used.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
135
/// Builder parameters for [`tool_call_execute`].
///
/// Every `Option` field exposes a setter taking the inner value and an `_opt`
/// setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallExecuteParams {
    /// Tool name recorded on emitted lifecycle events; the setter accepts any
    /// `Into<String>` value.
    #[builder(setter(into))]
    pub name: String,
    /// Raw tool arguments passed into the managed pipeline (guardrails and
    /// request intercepts run on them before execution).
    pub args: Json,
    /// Tool callback or execution continuation, wrapped by the
    /// execution-intercept chain.
    pub func: ToolExecutionNextFn,
    /// Optional explicit parent scope for the emitted tool span.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Tool attribute bitflags applied to the managed span; empty unless set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional application payload stored on the handle but not emitted as
    /// Agent Trajectory Observability Format (ATOF) data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on emitted events.
    #[builder(default)]
    pub metadata: Option<Json>,
}
161
/// Builder parameters for [`tool_call_end`].
///
/// Every `Option` field exposes a setter taking the inner value and an `_opt`
/// setter taking the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallEndParams<'a> {
    /// Tool handle to close.
    pub handle: &'a ToolHandle,
    /// Raw tool result. Its sanitized form (after sanitize-response
    /// guardrails) becomes the Agent Trajectory Observability Format (ATOF)
    /// end-event data.
    pub result: Json,
    /// Fallback payload retained for compatibility. It is used as the end-event
    /// data only when the sanitized `result` is JSON null.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional timestamp recorded on the emitted end event. When omitted, the
    /// runtime records the current UTC time, or one microsecond after the
    /// handle start time if the current time is not later.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
183
184/// Start a manual tool lifecycle span.
185///
186/// This emits a tool-start event after applying sanitize-request guardrails to
187/// the payload recorded for observability.
188///
189/// # Parameters
190/// - `name`: Tool name recorded on the emitted lifecycle event.
191/// - `args`: Raw tool arguments associated with the span.
192/// - `parent`: Optional explicit parent scope.
193/// - `attributes`: Tool attribute bitflags applied to the span.
194/// - `data`: Optional application payload stored on the returned handle. The
195///   emitted start event data is the sanitized `args` payload.
196/// - `metadata`: Optional JSON metadata recorded on the start event.
197/// - `tool_call_id`: Optional provider-specific correlation identifier.
198/// - `timestamp`: Optional timestamp recorded as the handle start time and on
199///   the emitted start event. When `None`, the current UTC time is used.
200///
201/// # Returns
202/// A [`Result`] containing the created [`ToolHandle`].
203///
204/// # Errors
205/// Returns an error when the runtime owner check fails or when internal state
206/// cannot be read safely.
207///
208/// # Notes
209/// Sanitize-request guardrails affect only the emitted start-event payload, not
210/// the caller-owned `args` value.
211pub fn tool_call(params: ToolCallParams<'_>) -> Result<ToolHandle> {
212    ensure_runtime_owner()?;
213    let parent_uuid = resolve_parent_uuid(params.parent);
214    let (handle, event, subscribers) = {
215        let scope_stack = current_scope_stack();
216        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
217        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
218            &registries.tool_sanitize_request_guardrails
219        });
220        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
221        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
222        let context = global_context();
223        let state = context
224            .read()
225            .map_err(|error| FlowError::Internal(error.to_string()))?;
226
227        let sanitized_args =
228            state.tool_sanitize_request_chain(params.name, params.args, &scope_locals);
229        let handle_params = CreateToolHandleParams::builder()
230            .name(params.name)
231            .parent_uuid_opt(parent_uuid)
232            .attributes(params.attributes)
233            .data_opt(params.data)
234            .metadata_opt(params.metadata)
235            .tool_call_id_opt(params.tool_call_id)
236            .timestamp_opt(params.timestamp)
237            .build();
238        let handle = state.create_tool_handle(handle_params);
239        let event = state.build_tool_start_event(&handle, Some(sanitized_args));
240        (handle, event, subscribers)
241    };
242    NemoFlowContextState::emit_event(&event, &subscribers);
243    Ok(handle)
244}
245
246/// Finish a manual tool lifecycle span.
247///
248/// This emits a tool-end event for a handle previously returned by
249/// [`tool_call`].
250///
251/// # Parameters
252/// - `handle`: Tool handle to close.
253/// - `result`: Raw tool result associated with the end event.
254/// - `data`: Optional application payload retained for compatibility. The
255///   emitted end event data is the sanitized `result` unless it sanitizes to
256///   JSON null, in which case this payload is used.
257/// - `metadata`: Optional JSON metadata recorded on the end event.
258/// - `timestamp`: Optional timestamp recorded on the emitted end event. When
259///   `None`, the runtime uses the current UTC time, or one microsecond after
260///   the handle start time if the current time is not later.
261///
262/// # Returns
263/// A [`Result`] that is `Ok(())` when the end event has been emitted.
264///
265/// # Errors
266/// Returns an error when the runtime owner check fails or when internal state
267/// cannot be read safely.
268///
269/// # Notes
270/// Sanitize-response guardrails affect only the emitted end-event payload, not
271/// the caller-owned `result` value.
272pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> {
273    ensure_runtime_owner()?;
274    let (event, subscribers) = {
275        let scope_stack = current_scope_stack();
276        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
277        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
278            &registries.tool_sanitize_response_guardrails
279        });
280        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
281        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
282        let context = global_context();
283        let state = context
284            .read()
285            .map_err(|error| FlowError::Internal(error.to_string()))?;
286
287        let sanitized_result =
288            state.tool_sanitize_response_chain(&params.handle.name, params.result, &scope_locals);
289        let data = if sanitized_result.is_null() {
290            params.data
291        } else {
292            Some(sanitized_result)
293        };
294        let event = state.build_tool_end_event(
295            EndToolHandleParams::builder()
296                .handle(params.handle)
297                .data_opt(data)
298                .metadata_opt(params.metadata)
299                .timestamp_opt(params.timestamp)
300                .build(),
301        );
302        (event, subscribers)
303    };
304    NemoFlowContextState::emit_event(&event, &subscribers);
305    Ok(())
306}
307
308fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option<Json>) -> Result<()> {
309    ensure_runtime_owner()?;
310    let (event, subscribers) = {
311        let scope_stack = current_scope_stack();
312        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
313        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
314        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
315        let context = global_context();
316        let state = context
317            .read()
318            .map_err(|error| FlowError::Internal(error.to_string()))?;
319        let event = state.end_tool_handle(handle, handle.data.clone(), metadata);
320        (event, subscribers)
321    };
322    NemoFlowContextState::emit_event(&event, &subscribers);
323    Ok(())
324}
325
326/// Execute a tool call through the managed middleware pipeline.
327///
328/// This runs conditional-execution guardrails, request intercepts,
329/// sanitize-request guardrails, execution intercepts, the tool callback, and
330/// sanitize-response guardrails in the runtime-defined order.
331///
332/// # Parameters
333/// - `name`: Tool name recorded on emitted lifecycle events.
334/// - `args`: Raw tool arguments passed into the managed pipeline.
335/// - `func`: Tool callback or execution continuation.
336/// - `parent`: Optional explicit parent scope for the emitted tool span.
337/// - `attributes`: Tool attribute bitflags applied to the managed span.
338/// - `data`: Optional application payload stored on the managed tool handle.
339///   It may be used on failure end events that have no output payload.
340/// - `metadata`: Optional JSON metadata recorded on emitted events.
341///
342/// # Returns
343/// A [`Result`] containing the raw tool result returned by the callback or an
344/// execution intercept.
345///
346/// # Errors
347/// Returns [`FlowError::GuardrailRejected`] when conditional-execution
348/// guardrails block the call, or any error raised by request intercepts,
349/// execution intercepts, or the callback itself.
350///
351/// # Notes
352/// When execution fails after the start event has been emitted, the runtime
353/// still emits a tool-end event without an output payload.
354pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result<Json> {
355    let ToolCallExecuteParams {
356        name,
357        args,
358        func,
359        parent,
360        attributes,
361        data,
362        metadata,
363    } = params;
364    ensure_runtime_owner()?;
365    {
366        let scope_stack = current_scope_stack();
367        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
368        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
369            &registries.tool_conditional_execution_guardrails
370        });
371        let context = global_context();
372        let state = context
373            .read()
374            .map_err(|error| FlowError::Internal(error.to_string()))?;
375        if let Some(error) = state.tool_conditional_execution_chain(&name, &args, &scope_locals)? {
376            drop(state);
377            drop(scope_guard);
378            let mut rejection_data = json!({});
379            if let Some(object) = rejection_data.as_object_mut() {
380                object.insert("rejected".into(), json!(true));
381                object.insert("rejection_reason".into(), json!(&error));
382            }
383            let _ = event(
384                EmitMarkEventParams::builder()
385                    .name(&name)
386                    .parent_opt(parent.as_ref())
387                    .data(rejection_data)
388                    .metadata_opt(metadata.clone())
389                    .build(),
390            );
391            return Err(FlowError::GuardrailRejected(error));
392        }
393    }
394
395    let intercepted_args = {
396        let scope_stack = current_scope_stack();
397        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
398        let scope_locals = scope_guard
399            .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
400        let context = global_context();
401        let state = context
402            .read()
403            .map_err(|error| FlowError::Internal(error.to_string()))?;
404        state.tool_request_intercepts_chain(&name, args, &scope_locals)?
405    };
406
407    let handle = tool_call(
408        ToolCallParams::builder()
409            .name(name.as_str())
410            .args(intercepted_args.clone())
411            .parent_opt(parent.as_ref())
412            .attributes(attributes)
413            .data_opt(data.clone())
414            .metadata_opt(metadata.clone())
415            .build(),
416    )?;
417
418    let execution = {
419        let scope_stack = current_scope_stack();
420        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
421        let scope_locals = scope_guard
422            .collect_scope_local_registries(|registries| &registries.tool_execution_intercepts);
423        let context = global_context();
424        let state = context
425            .read()
426            .map_err(|error| FlowError::Internal(error.to_string()))?;
427        state.tool_build_execution_chain(&name, func, &scope_locals)
428    };
429
430    match execution(intercepted_args).await {
431        Ok(result) => {
432            tool_call_end(
433                ToolCallEndParams::builder()
434                    .handle(&handle)
435                    .result(result.clone())
436                    .data_opt(data)
437                    .metadata_opt(metadata)
438                    .build(),
439            )?;
440            Ok(result)
441        }
442        Err(error) => {
443            let _ = emit_tool_end_without_output(&handle, metadata);
444            Err(error)
445        }
446    }
447}
448
449/// Run only the tool request-intercept chain.
450///
451/// This applies the currently active global and scope-local request intercepts
452/// without emitting lifecycle events or invoking tool execution.
453///
454/// # Parameters
455/// - `name`: Tool name used when resolving the intercept chain.
456/// - `args`: Raw tool arguments to transform.
457///
458/// # Returns
459/// A [`Result`] containing the transformed JSON arguments.
460///
461/// # Errors
462/// Returns any error raised by the request-intercept chain.
463///
464/// # Notes
465/// Conditional guardrails and execution intercepts are not run by this helper.
466pub fn tool_request_intercepts(name: &str, args: Json) -> Result<Json> {
467    ensure_runtime_owner()?;
468    let scope_stack = current_scope_stack();
469    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
470    let scope_locals = scope_guard
471        .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
472    let context = global_context();
473    let state = context
474        .read()
475        .map_err(|error| FlowError::Internal(error.to_string()))?;
476    state.tool_request_intercepts_chain(name, args, &scope_locals)
477}
478
479/// Run only the tool conditional-execution guardrail chain.
480///
481/// This evaluates whether a tool call should be allowed to proceed without
482/// emitting lifecycle events or invoking request intercepts or execution.
483///
484/// # Parameters
485/// - `name`: Tool name used when resolving the guardrail chain.
486/// - `args`: Raw tool arguments to validate.
487///
488/// # Returns
489/// A [`Result`] that is `Ok(())` when all guardrails allow execution.
490///
491/// # Errors
492/// Returns [`FlowError::GuardrailRejected`] when a guardrail blocks execution,
493/// or any error raised by the guardrail chain itself.
494///
495/// # Notes
496/// This helper is useful for preflight checks when the caller needs the
497/// rejection result without starting a tool span.
498pub fn tool_conditional_execution(name: &str, args: &Json) -> Result<()> {
499    ensure_runtime_owner()?;
500    let scope_stack = current_scope_stack();
501    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
502    let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
503        &registries.tool_conditional_execution_guardrails
504    });
505    let context = global_context();
506    let state = context
507        .read()
508        .map_err(|error| FlowError::Internal(error.to_string()))?;
509    if let Some(error) = state.tool_conditional_execution_chain(name, args, &scope_locals)? {
510        return Err(FlowError::GuardrailRejected(error));
511    }
512    Ok(())
513}