Skip to main content

nemo_flow/api/
tool.rs

1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use serde_json::json;
5
6use crate::api::runtime::NemoFlowContextState;
7use crate::api::runtime::ToolExecutionNextFn;
8use crate::api::runtime::current_scope_stack;
9use crate::api::runtime::global_context;
10use crate::api::scope::event;
11use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
12use crate::api::shared::{ensure_runtime_owner, resolve_parent_uuid, snapshot_event_subscribers};
13use crate::error::{FlowError, Result};
14use crate::json::Json;
15use bitflags::bitflags;
16use chrono::{DateTime, Utc};
17use serde::{Deserialize, Serialize};
18use typed_builder::TypedBuilder;
19use uuid::Uuid;
20
bitflags! {
    /// Bitflags that modify tool-call behavior and observability.
    ///
    /// The flag set is backed by a `u32`. [`ToolAttributes::empty`] is used as
    /// the builder default wherever this file declares an `attributes` field.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct ToolAttributes: u32 {
        /// Marks the tool as executing out-of-process (lowest bit, `0b01`).
        const REMOTE = 0b01;
    }
}
29
/// Runtime-owned handle identifying an active or completed tool call.
///
/// [`tool_call`] returns a handle of this type and [`tool_call_end`] borrows
/// it to emit the matching end event. Only `name` is mandatory when building a
/// handle; every other field has a builder default.
// Builder convention: `strip_option` lets `Option` fields take the inner value
// directly, the generated `*_opt` setters take the `Option` itself, and
// `ignore_invalid` leaves non-`Option` fields untouched.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolHandle {
    /// Unique tool-call identifier. Defaults to `Uuid::now_v7()`.
    #[builder(default = Uuid::now_v7())]
    pub uuid: Uuid,
    /// Timestamp captured when the tool handle was created. Defaults to
    /// `Utc::now()`.
    #[builder(default = Utc::now())]
    pub started_at: DateTime<Utc>,
    /// Tool name recorded on lifecycle events. The setter accepts anything
    /// convertible into a `String`.
    #[builder(setter(into))]
    pub name: String,
    /// Optional application payload stored on the handle. Defaults to `None`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata attached to the tool span. Defaults to `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Tool behavior flags. Defaults to the empty flag set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// UUID of the parent scope, if any. Defaults to `None`.
    #[builder(default)]
    pub parent_uuid: Option<Uuid>,
    /// Optional provider-specific tool-call correlation identifier. Defaults
    /// to `None`.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
}
59
/// Builder parameters for [`NemoFlowContextState::create_tool_handle`].
///
/// [`tool_call`] assembles this struct from its own parameters before asking
/// the runtime state for a handle. Only `name` is mandatory.
// Builder convention: `Option` fields take the inner value directly; the
// generated `*_opt` setters take the `Option` itself.
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct CreateToolHandleParams<'a> {
    /// Tool name recorded on emitted events. Borrowed for the lifetime `'a`.
    pub name: &'a str,
    /// Optional parent scope UUID. Defaults to `None`.
    #[builder(default)]
    pub parent_uuid: Option<uuid::Uuid>,
    /// Tool attribute bitflags. Defaults to the empty flag set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional application payload stored on the handle. Defaults to `None`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata stored on the handle. Defaults to `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional provider-specific correlation identifier. Defaults to `None`.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
    /// Optional timestamp captured as the handle start time and reused by the
    /// emitted start event. When omitted, the current UTC time is used.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
86
/// Builder parameters for [`NemoFlowContextState::build_tool_end_event`].
///
/// [`tool_call_end`] builds this struct after sanitizing the tool result. Only
/// `handle` is mandatory.
// Builder convention: `Option` fields take the inner value directly; the
// generated `*_opt` setters take the `Option` itself.
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndToolHandleParams<'a> {
    /// Tool handle to serialize into the emitted end event.
    pub handle: &'a ToolHandle,
    /// Optional data payload merged over the handle data. Defaults to `None`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata payload merged over the handle metadata. Defaults to
    /// `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional timestamp recorded on the emitted end event. When omitted, the
    /// runtime records the current UTC time, or one microsecond after the
    /// handle start time if the current time is not later.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
105
/// Builder parameters for [`tool_call`].
///
/// `name` and `args` are mandatory; every other field has a builder default.
// Builder convention: `Option` fields take the inner value directly; the
// generated `*_opt` setters take the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallParams<'a> {
    /// Tool name recorded on the emitted lifecycle event.
    pub name: &'a str,
    /// Raw tool arguments associated with the span. [`tool_call`] passes them
    /// through the sanitize-request chain and emits the sanitized value as the
    /// start-event payload.
    pub args: Json,
    /// Optional explicit parent scope. Defaults to `None`.
    #[builder(default)]
    pub parent: Option<&'a ScopeHandle>,
    /// Tool attribute bitflags applied to the span. Defaults to the empty flag
    /// set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional application payload stored on the handle but not emitted as ATOF data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on the start event. Defaults to `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional provider-specific correlation identifier. Defaults to `None`.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
    /// Optional timestamp captured as the handle start time and reused by the
    /// emitted start event. When omitted, the current UTC time is used.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
134
/// Builder parameters for [`tool_call_execute`].
///
/// `name`, `args`, and `func` are mandatory. Unlike [`ToolCallParams`], this
/// struct owns its name and parent scope and carries no `tool_call_id` or
/// `timestamp`; [`tool_call_execute`] starts the span without setting either.
// Builder convention: `Option` fields take the inner value directly; the
// generated `*_opt` setters take the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallExecuteParams {
    /// Tool name recorded on emitted lifecycle events. The setter accepts
    /// anything convertible into a `String`.
    #[builder(setter(into))]
    pub name: String,
    /// Raw tool arguments passed into the managed pipeline.
    pub args: Json,
    /// Tool callback or execution continuation.
    pub func: ToolExecutionNextFn,
    /// Optional explicit parent scope for the emitted tool span. Defaults to
    /// `None`.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Tool attribute bitflags applied to the managed span. Defaults to the
    /// empty flag set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional application payload stored on the handle but not emitted as ATOF data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on emitted events. Defaults to `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
}
159
/// Builder parameters for [`tool_call_end`].
///
/// `handle` and `result` are mandatory; every other field has a builder
/// default.
// Builder convention: `Option` fields take the inner value directly; the
// generated `*_opt` setters take the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallEndParams<'a> {
    /// Tool handle to close.
    pub handle: &'a ToolHandle,
    /// Raw tool result associated with the end event. [`tool_call_end`] runs
    /// it through the sanitize-response chain before emitting it.
    pub result: Json,
    /// Optional application payload retained for compatibility; ATOF data is the result.
    /// [`tool_call_end`] falls back to this payload only when the sanitized
    /// result is JSON null.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata recorded on the end event. Defaults to `None`.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional timestamp recorded on the emitted end event. When omitted, the
    /// runtime records the current UTC time, or one microsecond after the
    /// handle start time if the current time is not later.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
180
181/// Start a manual tool lifecycle span.
182///
183/// This emits a tool-start event after applying sanitize-request guardrails to
184/// the payload recorded for observability.
185///
186/// # Parameters
187/// - `name`: Tool name recorded on the emitted lifecycle event.
188/// - `args`: Raw tool arguments associated with the span.
189/// - `parent`: Optional explicit parent scope.
190/// - `attributes`: Tool attribute bitflags applied to the span.
191/// - `data`: Optional application payload stored on the returned handle. The
192///   emitted start event data is the sanitized `args` payload.
193/// - `metadata`: Optional JSON metadata recorded on the start event.
194/// - `tool_call_id`: Optional provider-specific correlation identifier.
195/// - `timestamp`: Optional timestamp recorded as the handle start time and on
196///   the emitted start event. When `None`, the current UTC time is used.
197///
198/// # Returns
199/// A [`Result`] containing the created [`ToolHandle`].
200///
201/// # Errors
202/// Returns an error when the runtime owner check fails or when internal state
203/// cannot be read safely.
204///
205/// # Notes
206/// Sanitize-request guardrails affect only the emitted start-event payload, not
207/// the caller-owned `args` value.
208pub fn tool_call(params: ToolCallParams<'_>) -> Result<ToolHandle> {
209    ensure_runtime_owner()?;
210    let parent_uuid = resolve_parent_uuid(params.parent);
211    let (handle, event, subscribers) = {
212        let scope_stack = current_scope_stack();
213        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
214        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
215            &registries.tool_sanitize_request_guardrails
216        });
217        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
218        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
219        let context = global_context();
220        let state = context
221            .read()
222            .map_err(|error| FlowError::Internal(error.to_string()))?;
223
224        let sanitized_args =
225            state.tool_sanitize_request_chain(params.name, params.args, &scope_locals);
226        let handle_params = CreateToolHandleParams::builder()
227            .name(params.name)
228            .parent_uuid_opt(parent_uuid)
229            .attributes(params.attributes)
230            .data_opt(params.data)
231            .metadata_opt(params.metadata)
232            .tool_call_id_opt(params.tool_call_id)
233            .timestamp_opt(params.timestamp)
234            .build();
235        let handle = state.create_tool_handle(handle_params);
236        let event = state.build_tool_start_event(&handle, Some(sanitized_args));
237        (handle, event, subscribers)
238    };
239    NemoFlowContextState::emit_event(&event, &subscribers);
240    Ok(handle)
241}
242
243/// Finish a manual tool lifecycle span.
244///
245/// This emits a tool-end event for a handle previously returned by
246/// [`tool_call`].
247///
248/// # Parameters
249/// - `handle`: Tool handle to close.
250/// - `result`: Raw tool result associated with the end event.
251/// - `data`: Optional application payload retained for compatibility. The
252///   emitted end event data is the sanitized `result` unless it sanitizes to
253///   JSON null, in which case this payload is used.
254/// - `metadata`: Optional JSON metadata recorded on the end event.
255/// - `timestamp`: Optional timestamp recorded on the emitted end event. When
256///   `None`, the runtime uses the current UTC time, or one microsecond after
257///   the handle start time if the current time is not later.
258///
259/// # Returns
260/// A [`Result`] that is `Ok(())` when the end event has been emitted.
261///
262/// # Errors
263/// Returns an error when the runtime owner check fails or when internal state
264/// cannot be read safely.
265///
266/// # Notes
267/// Sanitize-response guardrails affect only the emitted end-event payload, not
268/// the caller-owned `result` value.
269pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> {
270    ensure_runtime_owner()?;
271    let (event, subscribers) = {
272        let scope_stack = current_scope_stack();
273        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
274        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
275            &registries.tool_sanitize_response_guardrails
276        });
277        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
278        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
279        let context = global_context();
280        let state = context
281            .read()
282            .map_err(|error| FlowError::Internal(error.to_string()))?;
283
284        let sanitized_result =
285            state.tool_sanitize_response_chain(&params.handle.name, params.result, &scope_locals);
286        let data = if sanitized_result.is_null() {
287            params.data
288        } else {
289            Some(sanitized_result)
290        };
291        let event = state.build_tool_end_event(
292            EndToolHandleParams::builder()
293                .handle(params.handle)
294                .data_opt(data)
295                .metadata_opt(params.metadata)
296                .timestamp_opt(params.timestamp)
297                .build(),
298        );
299        (event, subscribers)
300    };
301    NemoFlowContextState::emit_event(&event, &subscribers);
302    Ok(())
303}
304
305fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option<Json>) -> Result<()> {
306    ensure_runtime_owner()?;
307    let (event, subscribers) = {
308        let scope_stack = current_scope_stack();
309        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
310        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
311        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
312        let context = global_context();
313        let state = context
314            .read()
315            .map_err(|error| FlowError::Internal(error.to_string()))?;
316        let event = state.end_tool_handle(handle, handle.data.clone(), metadata);
317        (event, subscribers)
318    };
319    NemoFlowContextState::emit_event(&event, &subscribers);
320    Ok(())
321}
322
323/// Execute a tool call through the managed middleware pipeline.
324///
325/// This runs conditional-execution guardrails, request intercepts,
326/// sanitize-request guardrails, execution intercepts, the tool callback, and
327/// sanitize-response guardrails in the runtime-defined order.
328///
329/// # Parameters
330/// - `name`: Tool name recorded on emitted lifecycle events.
331/// - `args`: Raw tool arguments passed into the managed pipeline.
332/// - `func`: Tool callback or execution continuation.
333/// - `parent`: Optional explicit parent scope for the emitted tool span.
334/// - `attributes`: Tool attribute bitflags applied to the managed span.
335/// - `data`: Optional application payload stored on the managed tool handle.
336///   It may be used on failure end events that have no output payload.
337/// - `metadata`: Optional JSON metadata recorded on emitted events.
338///
339/// # Returns
340/// A [`Result`] containing the raw tool result returned by the callback or an
341/// execution intercept.
342///
343/// # Errors
344/// Returns [`FlowError::GuardrailRejected`] when conditional-execution
345/// guardrails block the call, or any error raised by request intercepts,
346/// execution intercepts, or the callback itself.
347///
348/// # Notes
349/// When execution fails after the start event has been emitted, the runtime
350/// still emits a tool-end event without an output payload.
351pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result<Json> {
352    let ToolCallExecuteParams {
353        name,
354        args,
355        func,
356        parent,
357        attributes,
358        data,
359        metadata,
360    } = params;
361    ensure_runtime_owner()?;
362    {
363        let scope_stack = current_scope_stack();
364        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
365        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
366            &registries.tool_conditional_execution_guardrails
367        });
368        let context = global_context();
369        let state = context
370            .read()
371            .map_err(|error| FlowError::Internal(error.to_string()))?;
372        if let Some(error) = state.tool_conditional_execution_chain(&name, &args, &scope_locals)? {
373            drop(state);
374            drop(scope_guard);
375            let mut rejection_data = json!({});
376            if let Some(object) = rejection_data.as_object_mut() {
377                object.insert("rejected".into(), json!(true));
378                object.insert("rejection_reason".into(), json!(&error));
379            }
380            let _ = event(
381                EmitMarkEventParams::builder()
382                    .name(&name)
383                    .parent_opt(parent.as_ref())
384                    .data(rejection_data)
385                    .metadata_opt(metadata.clone())
386                    .build(),
387            );
388            return Err(FlowError::GuardrailRejected(error));
389        }
390    }
391
392    let intercepted_args = {
393        let scope_stack = current_scope_stack();
394        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
395        let scope_locals = scope_guard
396            .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
397        let context = global_context();
398        let state = context
399            .read()
400            .map_err(|error| FlowError::Internal(error.to_string()))?;
401        state.tool_request_intercepts_chain(&name, args, &scope_locals)?
402    };
403
404    let handle = tool_call(
405        ToolCallParams::builder()
406            .name(name.as_str())
407            .args(intercepted_args.clone())
408            .parent_opt(parent.as_ref())
409            .attributes(attributes)
410            .data_opt(data.clone())
411            .metadata_opt(metadata.clone())
412            .build(),
413    )?;
414
415    let execution = {
416        let scope_stack = current_scope_stack();
417        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
418        let scope_locals = scope_guard
419            .collect_scope_local_registries(|registries| &registries.tool_execution_intercepts);
420        let context = global_context();
421        let state = context
422            .read()
423            .map_err(|error| FlowError::Internal(error.to_string()))?;
424        state.tool_build_execution_chain(&name, func, &scope_locals)
425    };
426
427    match execution(intercepted_args).await {
428        Ok(result) => {
429            tool_call_end(
430                ToolCallEndParams::builder()
431                    .handle(&handle)
432                    .result(result.clone())
433                    .data_opt(data)
434                    .metadata_opt(metadata)
435                    .build(),
436            )?;
437            Ok(result)
438        }
439        Err(error) => {
440            let _ = emit_tool_end_without_output(&handle, metadata);
441            Err(error)
442        }
443    }
444}
445
446/// Run only the tool request-intercept chain.
447///
448/// This applies the currently active global and scope-local request intercepts
449/// without emitting lifecycle events or invoking tool execution.
450///
451/// # Parameters
452/// - `name`: Tool name used when resolving the intercept chain.
453/// - `args`: Raw tool arguments to transform.
454///
455/// # Returns
456/// A [`Result`] containing the transformed JSON arguments.
457///
458/// # Errors
459/// Returns any error raised by the request-intercept chain.
460///
461/// # Notes
462/// Conditional guardrails and execution intercepts are not run by this helper.
463pub fn tool_request_intercepts(name: &str, args: Json) -> Result<Json> {
464    ensure_runtime_owner()?;
465    let scope_stack = current_scope_stack();
466    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
467    let scope_locals = scope_guard
468        .collect_scope_local_registries(|registries| &registries.tool_request_intercepts);
469    let context = global_context();
470    let state = context
471        .read()
472        .map_err(|error| FlowError::Internal(error.to_string()))?;
473    state.tool_request_intercepts_chain(name, args, &scope_locals)
474}
475
476/// Run only the tool conditional-execution guardrail chain.
477///
478/// This evaluates whether a tool call should be allowed to proceed without
479/// emitting lifecycle events or invoking request intercepts or execution.
480///
481/// # Parameters
482/// - `name`: Tool name used when resolving the guardrail chain.
483/// - `args`: Raw tool arguments to validate.
484///
485/// # Returns
486/// A [`Result`] that is `Ok(())` when all guardrails allow execution.
487///
488/// # Errors
489/// Returns [`FlowError::GuardrailRejected`] when a guardrail blocks execution,
490/// or any error raised by the guardrail chain itself.
491///
492/// # Notes
493/// This helper is useful for preflight checks when the caller needs the
494/// rejection result without starting a tool span.
495pub fn tool_conditional_execution(name: &str, args: &Json) -> Result<()> {
496    ensure_runtime_owner()?;
497    let scope_stack = current_scope_stack();
498    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
499    let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
500        &registries.tool_conditional_execution_guardrails
501    });
502    let context = global_context();
503    let state = context
504        .read()
505        .map_err(|error| FlowError::Internal(error.to_string()))?;
506    if let Some(error) = state.tool_conditional_execution_chain(name, args, &scope_locals)? {
507        return Err(FlowError::GuardrailRejected(error));
508    }
509    Ok(())
510}