1use serde_json::json;
5
6use crate::api::runtime::NemoFlowContextState;
7use crate::api::runtime::ToolExecutionNextFn;
8use crate::api::runtime::current_scope_stack;
9use crate::api::runtime::global_context;
10use crate::api::scope::event;
11use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
12use crate::api::shared::{ensure_runtime_owner, resolve_parent_uuid, snapshot_event_subscribers};
13use crate::error::{FlowError, Result};
14use crate::json::Json;
15use bitflags::bitflags;
16use chrono::{DateTime, Utc};
17use serde::{Deserialize, Serialize};
18use typed_builder::TypedBuilder;
19use uuid::Uuid;
20
bitflags! {
    /// Bit flags describing properties of a tool invocation.
    ///
    /// Serializable (via the `Serialize`/`Deserialize` derives) so the flags can
    /// travel with a serialized [`ToolHandle`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct ToolAttributes: u32 {
        /// Marks the tool as remote. NOTE(review): presumably means the tool runs
        /// outside this process — confirm the intended semantics with its users.
        const REMOTE = 0b01;
    }
}
29
/// Record of a single tool invocation, as returned by [`tool_call`].
///
/// Built with the `TypedBuilder`-generated builder. Because of the
/// `strip_option(ignore_invalid, fallback_suffix = "_opt")` field default, each
/// `Option` field gets a setter taking the inner value plus a `<field>_opt`
/// setter taking the `Option` itself. Non-`Option` fields keep a plain setter.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolHandle {
    /// Unique identifier of this invocation; defaults to a new UUIDv7 when not set.
    #[builder(default = Uuid::now_v7())]
    pub uuid: Uuid,
    /// Start time of the invocation; defaults to `Utc::now()` when not set.
    #[builder(default = Utc::now())]
    pub started_at: DateTime<Utc>,
    /// Tool name; the setter accepts anything convertible `Into<String>`.
    #[builder(setter(into))]
    pub name: String,
    /// Optional caller-supplied JSON payload.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional caller-supplied JSON metadata.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Attribute flags for the tool; defaults to no flags set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// UUID of the parent scope, if any.
    #[builder(default)]
    pub parent_uuid: Option<Uuid>,
    /// Optional caller-provided tool-call identifier; the setter accepts
    /// anything convertible `Into<String>`.
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
}
59
/// Inputs for creating a [`ToolHandle`] through the context state's
/// `create_tool_handle` (as done by [`tool_call`]).
///
/// `Option` fields get both a plain setter and a `<field>_opt` setter.
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct CreateToolHandleParams<'a> {
    /// Tool name.
    pub name: &'a str,
    /// UUID of the parent scope, if any.
    #[builder(default)]
    pub parent_uuid: Option<uuid::Uuid>,
    /// Attribute flags for the tool; defaults to no flags set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional JSON payload for the handle.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata for the handle.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional caller-provided tool-call identifier (setter accepts `Into<String>`).
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
    /// Optional start timestamp override.
    /// NOTE(review): presumably the current time is used when `None` — confirm in
    /// `create_tool_handle`.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
86
/// Inputs for building a tool end event through the context state's
/// `build_tool_end_event` (as done by [`tool_call_end`]).
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndToolHandleParams<'a> {
    /// Handle of the tool invocation being ended.
    pub handle: &'a ToolHandle,
    /// Optional JSON payload for the end event.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata for the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional end timestamp override.
    /// NOTE(review): presumably the current time is used when `None` — confirm in
    /// `build_tool_end_event`.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
105
/// Parameters for [`tool_call`], which starts a tool invocation.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallParams<'a> {
    /// Tool name.
    pub name: &'a str,
    /// Raw tool arguments; sanitized by the request-sanitize guardrails before
    /// being recorded on the start event.
    pub args: Json,
    /// Explicit parent scope; turned into the handle's parent UUID by
    /// `resolve_parent_uuid`.
    #[builder(default)]
    pub parent: Option<&'a ScopeHandle>,
    /// Attribute flags for the tool; defaults to no flags set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional JSON payload stored on the created handle.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata stored on the created handle.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional caller-provided tool-call identifier (setter accepts `Into<String>`).
    #[builder(default, setter(into))]
    pub tool_call_id: Option<String>,
    /// Optional start timestamp override forwarded to handle creation.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
135
/// Parameters for [`tool_call_execute`], which runs a tool end to end.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct ToolCallExecuteParams {
    /// Tool name; the setter accepts anything convertible `Into<String>`.
    #[builder(setter(into))]
    pub name: String,
    /// Raw tool arguments, passed through guardrails and request intercepts.
    pub args: Json,
    /// The tool implementation, wrapped by the execution-intercept chain.
    pub func: ToolExecutionNextFn,
    /// Optional parent scope for the events emitted by the invocation.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Attribute flags for the tool; defaults to no flags set.
    #[builder(default = ToolAttributes::empty())]
    pub attributes: ToolAttributes,
    /// Optional JSON payload. It is stored on the handle and used as the
    /// end-event data when the sanitized result is `null`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata attached to the start, end, and rejection events.
    #[builder(default)]
    pub metadata: Option<Json>,
}
161
162#[derive(TypedBuilder)]
164#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
165pub struct ToolCallEndParams<'a> {
166 pub handle: &'a ToolHandle,
168 pub result: Json,
170 #[builder(default)]
173 pub data: Option<Json>,
174 #[builder(default)]
176 pub metadata: Option<Json>,
177 #[builder(default)]
181 pub timestamp: Option<DateTime<Utc>>,
182}
183
184pub fn tool_call(params: ToolCallParams<'_>) -> Result<ToolHandle> {
212 ensure_runtime_owner()?;
213 let parent_uuid = resolve_parent_uuid(params.parent);
214 let (handle, event, subscribers) = {
215 let scope_stack = current_scope_stack();
216 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
217 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
218 ®istries.tool_sanitize_request_guardrails
219 });
220 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
221 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
222 let context = global_context();
223 let state = context
224 .read()
225 .map_err(|error| FlowError::Internal(error.to_string()))?;
226
227 let sanitized_args =
228 state.tool_sanitize_request_chain(params.name, params.args, &scope_locals);
229 let handle_params = CreateToolHandleParams::builder()
230 .name(params.name)
231 .parent_uuid_opt(parent_uuid)
232 .attributes(params.attributes)
233 .data_opt(params.data)
234 .metadata_opt(params.metadata)
235 .tool_call_id_opt(params.tool_call_id)
236 .timestamp_opt(params.timestamp)
237 .build();
238 let handle = state.create_tool_handle(handle_params);
239 let event = state.build_tool_start_event(&handle, Some(sanitized_args));
240 (handle, event, subscribers)
241 };
242 NemoFlowContextState::emit_event(&event, &subscribers);
243 Ok(handle)
244}
245
246pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> {
273 ensure_runtime_owner()?;
274 let (event, subscribers) = {
275 let scope_stack = current_scope_stack();
276 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
277 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
278 ®istries.tool_sanitize_response_guardrails
279 });
280 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
281 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
282 let context = global_context();
283 let state = context
284 .read()
285 .map_err(|error| FlowError::Internal(error.to_string()))?;
286
287 let sanitized_result =
288 state.tool_sanitize_response_chain(¶ms.handle.name, params.result, &scope_locals);
289 let data = if sanitized_result.is_null() {
290 params.data
291 } else {
292 Some(sanitized_result)
293 };
294 let event = state.build_tool_end_event(
295 EndToolHandleParams::builder()
296 .handle(params.handle)
297 .data_opt(data)
298 .metadata_opt(params.metadata)
299 .timestamp_opt(params.timestamp)
300 .build(),
301 );
302 (event, subscribers)
303 };
304 NemoFlowContextState::emit_event(&event, &subscribers);
305 Ok(())
306}
307
308fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option<Json>) -> Result<()> {
309 ensure_runtime_owner()?;
310 let (event, subscribers) = {
311 let scope_stack = current_scope_stack();
312 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
313 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
314 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
315 let context = global_context();
316 let state = context
317 .read()
318 .map_err(|error| FlowError::Internal(error.to_string()))?;
319 let event = state.end_tool_handle(handle, handle.data.clone(), metadata);
320 (event, subscribers)
321 };
322 NemoFlowContextState::emit_event(&event, &subscribers);
323 Ok(())
324}
325
326pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result<Json> {
355 let ToolCallExecuteParams {
356 name,
357 args,
358 func,
359 parent,
360 attributes,
361 data,
362 metadata,
363 } = params;
364 ensure_runtime_owner()?;
365 {
366 let scope_stack = current_scope_stack();
367 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
368 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
369 ®istries.tool_conditional_execution_guardrails
370 });
371 let context = global_context();
372 let state = context
373 .read()
374 .map_err(|error| FlowError::Internal(error.to_string()))?;
375 if let Some(error) = state.tool_conditional_execution_chain(&name, &args, &scope_locals)? {
376 drop(state);
377 drop(scope_guard);
378 let mut rejection_data = json!({});
379 if let Some(object) = rejection_data.as_object_mut() {
380 object.insert("rejected".into(), json!(true));
381 object.insert("rejection_reason".into(), json!(&error));
382 }
383 let _ = event(
384 EmitMarkEventParams::builder()
385 .name(&name)
386 .parent_opt(parent.as_ref())
387 .data(rejection_data)
388 .metadata_opt(metadata.clone())
389 .build(),
390 );
391 return Err(FlowError::GuardrailRejected(error));
392 }
393 }
394
395 let intercepted_args = {
396 let scope_stack = current_scope_stack();
397 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
398 let scope_locals = scope_guard
399 .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts);
400 let context = global_context();
401 let state = context
402 .read()
403 .map_err(|error| FlowError::Internal(error.to_string()))?;
404 state.tool_request_intercepts_chain(&name, args, &scope_locals)?
405 };
406
407 let handle = tool_call(
408 ToolCallParams::builder()
409 .name(name.as_str())
410 .args(intercepted_args.clone())
411 .parent_opt(parent.as_ref())
412 .attributes(attributes)
413 .data_opt(data.clone())
414 .metadata_opt(metadata.clone())
415 .build(),
416 )?;
417
418 let execution = {
419 let scope_stack = current_scope_stack();
420 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
421 let scope_locals = scope_guard
422 .collect_scope_local_registries(|registries| ®istries.tool_execution_intercepts);
423 let context = global_context();
424 let state = context
425 .read()
426 .map_err(|error| FlowError::Internal(error.to_string()))?;
427 state.tool_build_execution_chain(&name, func, &scope_locals)
428 };
429
430 match execution(intercepted_args).await {
431 Ok(result) => {
432 tool_call_end(
433 ToolCallEndParams::builder()
434 .handle(&handle)
435 .result(result.clone())
436 .data_opt(data)
437 .metadata_opt(metadata)
438 .build(),
439 )?;
440 Ok(result)
441 }
442 Err(error) => {
443 let _ = emit_tool_end_without_output(&handle, metadata);
444 Err(error)
445 }
446 }
447}
448
449pub fn tool_request_intercepts(name: &str, args: Json) -> Result<Json> {
467 ensure_runtime_owner()?;
468 let scope_stack = current_scope_stack();
469 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
470 let scope_locals = scope_guard
471 .collect_scope_local_registries(|registries| ®istries.tool_request_intercepts);
472 let context = global_context();
473 let state = context
474 .read()
475 .map_err(|error| FlowError::Internal(error.to_string()))?;
476 state.tool_request_intercepts_chain(name, args, &scope_locals)
477}
478
479pub fn tool_conditional_execution(name: &str, args: &Json) -> Result<()> {
499 ensure_runtime_owner()?;
500 let scope_stack = current_scope_stack();
501 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
502 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
503 ®istries.tool_conditional_execution_guardrails
504 });
505 let context = global_context();
506 let state = context
507 .read()
508 .map_err(|error| FlowError::Internal(error.to_string()))?;
509 if let Some(error) = state.tool_conditional_execution_chain(name, args, &scope_locals)? {
510 return Err(FlowError::GuardrailRejected(error));
511 }
512 Ok(())
513}