1use serde_json::json;
5
6use crate::api::runtime::NemoFlowContextState;
7use crate::api::runtime::ToolExecutionNextFn;
8use crate::api::runtime::current_scope_stack;
9use crate::api::runtime::global_context;
10use crate::api::scope::event;
11use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
12use crate::api::shared::{ensure_runtime_owner, resolve_parent_uuid, snapshot_event_subscribers};
13use crate::error::{FlowError, Result};
14use crate::json::Json;
15use bitflags::bitflags;
16use chrono::{DateTime, Utc};
17use serde::{Deserialize, Serialize};
18use typed_builder::TypedBuilder;
19use uuid::Uuid;
20
21bitflags! {
22 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
24 pub struct ToolAttributes: u32 {
25 const REMOTE = 0b01;
27 }
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
32#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
33pub struct ToolHandle {
34 #[builder(default = Uuid::now_v7())]
36 pub uuid: Uuid,
37 #[builder(default = Utc::now())]
39 pub started_at: DateTime<Utc>,
40 #[builder(setter(into))]
42 pub name: String,
43 #[builder(default)]
45 pub data: Option<Json>,
46 #[builder(default)]
48 pub metadata: Option<Json>,
49 #[builder(default = ToolAttributes::empty())]
51 pub attributes: ToolAttributes,
52 #[builder(default)]
54 pub parent_uuid: Option<Uuid>,
55 #[builder(default, setter(into))]
57 pub tool_call_id: Option<String>,
58}
59
60#[derive(Debug, Clone, TypedBuilder)]
62#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
63pub struct CreateToolHandleParams<'a> {
64 pub name: &'a str,
66 #[builder(default)]
68 pub parent_uuid: Option<uuid::Uuid>,
69 #[builder(default = ToolAttributes::empty())]
71 pub attributes: ToolAttributes,
72 #[builder(default)]
74 pub data: Option<Json>,
75 #[builder(default)]
77 pub metadata: Option<Json>,
78 #[builder(default, setter(into))]
80 pub tool_call_id: Option<String>,
81 #[builder(default)]
84 pub timestamp: Option<DateTime<Utc>>,
85}
86
87#[derive(Debug, Clone, TypedBuilder)]
89#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
90pub struct EndToolHandleParams<'a> {
91 pub handle: &'a ToolHandle,
93 #[builder(default)]
95 pub data: Option<Json>,
96 #[builder(default)]
98 pub metadata: Option<Json>,
99 #[builder(default)]
103 pub timestamp: Option<DateTime<Utc>>,
104}
105
106#[derive(TypedBuilder)]
108#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
109pub struct ToolCallParams<'a> {
110 pub name: &'a str,
112 pub args: Json,
114 #[builder(default)]
116 pub parent: Option<&'a ScopeHandle>,
117 #[builder(default = ToolAttributes::empty())]
119 pub attributes: ToolAttributes,
120 #[builder(default)]
122 pub data: Option<Json>,
123 #[builder(default)]
125 pub metadata: Option<Json>,
126 #[builder(default, setter(into))]
128 pub tool_call_id: Option<String>,
129 #[builder(default)]
132 pub timestamp: Option<DateTime<Utc>>,
133}
134
135#[derive(TypedBuilder)]
137#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
138pub struct ToolCallExecuteParams {
139 #[builder(setter(into))]
141 pub name: String,
142 pub args: Json,
144 pub func: ToolExecutionNextFn,
146 #[builder(default)]
148 pub parent: Option<ScopeHandle>,
149 #[builder(default = ToolAttributes::empty())]
151 pub attributes: ToolAttributes,
152 #[builder(default)]
154 pub data: Option<Json>,
155 #[builder(default)]
157 pub metadata: Option<Json>,
158}
159
160#[derive(TypedBuilder)]
162#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
163pub struct ToolCallEndParams<'a> {
164 pub handle: &'a ToolHandle,
166 pub result: Json,
168 #[builder(default)]
170 pub data: Option<Json>,
171 #[builder(default)]
173 pub metadata: Option<Json>,
174 #[builder(default)]
178 pub timestamp: Option<DateTime<Utc>>,
179}
180
181pub fn tool_call(params: ToolCallParams<'_>) -> Result<ToolHandle> {
209 ensure_runtime_owner()?;
210 let parent_uuid = resolve_parent_uuid(params.parent);
211 let (handle, event, subscribers) = {
212 let scope_stack = current_scope_stack();
213 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
214 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
215 ®istries.tool_sanitize_request_guardrails
216 });
217 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
218 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
219 let context = global_context();
220 let state = context
221 .read()
222 .map_err(|error| FlowError::Internal(error.to_string()))?;
223
224 let sanitized_args =
225 state.tool_sanitize_request_chain(params.name, params.args, &scope_locals);
226 let handle_params = CreateToolHandleParams::builder()
227 .name(params.name)
228 .parent_uuid_opt(parent_uuid)
229 .attributes(params.attributes)
230 .data_opt(params.data)
231 .metadata_opt(params.metadata)
232 .tool_call_id_opt(params.tool_call_id)
233 .timestamp_opt(params.timestamp)
234 .build();
235 let handle = state.create_tool_handle(handle_params);
236 let event = state.build_tool_start_event(&handle, Some(sanitized_args));
237 (handle, event, subscribers)
238 };
239 NemoFlowContextState::emit_event(&event, &subscribers);
240 Ok(handle)
241}
242
243pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> {
270 ensure_runtime_owner()?;
271 let (event, subscribers) = {
272 let scope_stack = current_scope_stack();
273 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
274 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
275 ®istries.tool_sanitize_response_guardrails
276 });
277 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
278 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
279 let context = global_context();
280 let state = context
281 .read()
282 .map_err(|error| FlowError::Internal(error.to_string()))?;
283
284 let sanitized_result =
285 state.tool_sanitize_response_chain(¶ms.handle.name, params.result, &scope_locals);
286 let data = if sanitized_result.is_null() {
287 params.data
288 } else {
289 Some(sanitized_result)
290 };
291 let event = state.build_tool_end_event(
292 EndToolHandleParams::builder()
293 .handle(params.handle)
294 .data_opt(data)
295 .metadata_opt(params.metadata)
296 .timestamp_opt(params.timestamp)
297 .build(),
298 );
299 (event, subscribers)
300 };
301 NemoFlowContextState::emit_event(&event, &subscribers);
302 Ok(())
303}
304
305fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option<Json>) -> Result<()> {
306 ensure_runtime_owner()?;
307 let (event, subscribers) = {
308 let scope_stack = current_scope_stack();
309 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
310 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
311 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
312 let context = global_context();
313 let state = context
314 .read()
315 .map_err(|error| FlowError::Internal(error.to_string()))?;
316 let event = state.end_tool_handle(handle, handle.data.clone(), metadata);
317 (event, subscribers)
318 };
319 NemoFlowContextState::emit_event(&event, &subscribers);
320 Ok(())
321}
322
323pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result<Json> {
352 let ToolCallExecuteParams {
353 name,
354 args,
355 func,
356 parent,
357 attributes,
358 data,
359 metadata,
360 } = params;
361 ensure_runtime_owner()?;
362 {
363 let scope_stack = current_scope_stack();
364 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
365 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
366 ®istries.tool_conditional_execution_guardrails
367 });
368 let context = global_context();
369 let state = context
370 .read()
371 .map_err(|error| FlowError::Internal(error.to_string()))?;
372 if let Some(error) = state.tool_conditional_execution_chain(&name, &args, &scope_locals)? {
373 drop(state);
374 drop(scope_guard);
375 let mut rejection_data = json!({});
376 if let Some(object) = rejection_data.as_object_mut() {
377 object.insert("rejected".into(), json!(true));
378 object.insert("rejection_reason".into(), json!(&error));
379 }
380 let _ = event(
381 EmitMarkEventParams::builder()
382 .name(&name)
383 .parent_opt(parent.as_ref())
384 .data(rejection_data)
385 .metadata_opt(metadata.clone())
386 .build(),
387 );
388 return Err(FlowError::GuardrailRejected(error));
389 }
390 }
391
392 let intercepted_args = {
393 let scope_stack = current_scope_stack();
394 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
395 let scope_locals = scope_guard
396 .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts);
397 let context = global_context();
398 let state = context
399 .read()
400 .map_err(|error| FlowError::Internal(error.to_string()))?;
401 state.tool_request_intercepts_chain(&name, args, &scope_locals)?
402 };
403
404 let handle = tool_call(
405 ToolCallParams::builder()
406 .name(name.as_str())
407 .args(intercepted_args.clone())
408 .parent_opt(parent.as_ref())
409 .attributes(attributes)
410 .data_opt(data.clone())
411 .metadata_opt(metadata.clone())
412 .build(),
413 )?;
414
415 let execution = {
416 let scope_stack = current_scope_stack();
417 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
418 let scope_locals = scope_guard
419 .collect_scope_local_registries(|registries| ®istries.tool_execution_intercepts);
420 let context = global_context();
421 let state = context
422 .read()
423 .map_err(|error| FlowError::Internal(error.to_string()))?;
424 state.tool_build_execution_chain(&name, func, &scope_locals)
425 };
426
427 match execution(intercepted_args).await {
428 Ok(result) => {
429 tool_call_end(
430 ToolCallEndParams::builder()
431 .handle(&handle)
432 .result(result.clone())
433 .data_opt(data)
434 .metadata_opt(metadata)
435 .build(),
436 )?;
437 Ok(result)
438 }
439 Err(error) => {
440 let _ = emit_tool_end_without_output(&handle, metadata);
441 Err(error)
442 }
443 }
444}
445
446pub fn tool_request_intercepts(name: &str, args: Json) -> Result<Json> {
464 ensure_runtime_owner()?;
465 let scope_stack = current_scope_stack();
466 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
467 let scope_locals = scope_guard
468 .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts);
469 let context = global_context();
470 let state = context
471 .read()
472 .map_err(|error| FlowError::Internal(error.to_string()))?;
473 state.tool_request_intercepts_chain(name, args, &scope_locals)
474}
475
476pub fn tool_conditional_execution(name: &str, args: &Json) -> Result<()> {
496 ensure_runtime_owner()?;
497 let scope_stack = current_scope_stack();
498 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
499 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
500 ®istries.tool_conditional_execution_guardrails
501 });
502 let context = global_context();
503 let state = context
504 .read()
505 .map_err(|error| FlowError::Internal(error.to_string()))?;
506 if let Some(error) = state.tool_conditional_execution_chain(name, args, &scope_locals)? {
507 return Err(FlowError::GuardrailRejected(error));
508 }
509 Ok(())
510}