nemo_flow/api/runtime/state.rs
1// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Process-global runtime state and middleware-chain builders.
5//!
6//! [`NemoFlowContextState`] owns the registries and helper methods that power
7//! the public scope, tool, and LLM APIs. Advanced integrations can use this
8//! type directly to register middleware, attach runtime extensions, and build
9//! the resolved callback chains that the higher-level API layer executes.
10
11use std::any::Any;
12use std::collections::HashMap;
13use std::sync::Arc;
14
15use crate::api::event::{
16 BaseEvent, CategoryProfile, Event, EventCategory, MarkEvent, ScopeCategory, ScopeEvent,
17 llm_attributes_to_strings, scope_attributes_to_strings, tool_attributes_to_strings,
18};
19use crate::api::llm::{CreateLlmHandleParams, EndLlmHandleParams};
20use crate::api::llm::{LlmHandle, LlmRequest};
21use crate::api::registry::{ExecutionIntercept, GuardrailEntry, Intercept};
22use crate::api::runtime::callbacks::{
23 EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn,
24 LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn,
25 LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn,
26 ToolInterceptFn, ToolSanitizeFn,
27};
28use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle};
29use crate::api::tool::ToolHandle;
30use crate::api::tool::{CreateToolHandleParams, EndToolHandleParams};
31use crate::codec::request::AnnotatedLlmRequest;
32use crate::codec::response::AnnotatedLlmResponse;
33use crate::context::registries::{
34 merge_execution_intercept_callables, merge_guardrail_entries, merge_intercept_entries,
35};
36use crate::json::{Json, merge_json};
37use crate::registry::SortedRegistry;
38use chrono::{Duration, Utc};
39
40/// Process-global runtime state backing middleware and event emission.
41///
42/// The public API layer stores one shared instance of this type for the
43/// process. It contains global middleware registries, lifecycle subscribers,
44/// and arbitrary extension slots used by bindings or integrations.
45pub struct NemoFlowContextState {
46 /// Global tool request sanitizers applied to emitted tool-start payloads.
47 pub tool_sanitize_request_guardrails: SortedRegistry<GuardrailEntry<ToolSanitizeFn>>,
48 /// Global tool response sanitizers applied to emitted tool-end payloads.
49 pub tool_sanitize_response_guardrails: SortedRegistry<GuardrailEntry<ToolSanitizeFn>>,
50 /// Global tool guardrails that can reject execution before the callback runs.
51 pub tool_conditional_execution_guardrails: SortedRegistry<GuardrailEntry<ToolConditionalFn>>,
52 /// Global tool request intercepts that can rewrite arguments before execution.
53 pub tool_request_intercepts: SortedRegistry<Intercept<ToolInterceptFn>>,
54 /// Global tool execution intercepts that wrap or replace callback execution.
55 pub tool_execution_intercepts: SortedRegistry<ExecutionIntercept<ToolExecutionFn>>,
56 /// Global LLM request sanitizers applied to emitted LLM-start payloads.
57 pub llm_sanitize_request_guardrails: SortedRegistry<GuardrailEntry<LlmSanitizeRequestFn>>,
58 /// Global LLM response sanitizers applied to emitted LLM-end payloads.
59 pub llm_sanitize_response_guardrails: SortedRegistry<GuardrailEntry<LlmSanitizeResponseFn>>,
60 /// Global LLM guardrails that can reject execution before the provider callback runs.
61 pub llm_conditional_execution_guardrails: SortedRegistry<GuardrailEntry<LlmConditionalFn>>,
62 /// Global LLM request intercepts that can rewrite or annotate requests.
63 pub llm_request_intercepts: SortedRegistry<Intercept<LlmRequestInterceptFn>>,
64 /// Global non-streaming LLM execution intercepts that wrap callback execution.
65 pub llm_execution_intercepts: SortedRegistry<ExecutionIntercept<LlmExecutionFn>>,
66 /// Global streaming LLM execution intercepts that wrap stream-producing callbacks.
67 pub llm_stream_execution_intercepts: SortedRegistry<ExecutionIntercept<LlmStreamExecutionFn>>,
68 /// Global lifecycle subscribers notified after runtime events are emitted.
69 pub event_subscribers: HashMap<String, EventSubscriberFn>,
70 /// Arbitrary binding- or integration-specific runtime extensions.
71 pub extensions: HashMap<String, Box<dyn Any + Send + Sync>>,
72}
73
74impl NemoFlowContextState {
75 /// Create an empty runtime state with no registered middleware.
76 ///
77 /// # Returns
78 /// A [`NemoFlowContextState`] with empty registries, no subscribers, and no
79 /// extensions.
80 pub fn new() -> Self {
81 Self {
82 tool_sanitize_request_guardrails: SortedRegistry::new(|entry| entry.priority),
83 tool_sanitize_response_guardrails: SortedRegistry::new(|entry| entry.priority),
84 tool_conditional_execution_guardrails: SortedRegistry::new(|entry| entry.priority),
85 tool_request_intercepts: SortedRegistry::new(|entry| entry.priority),
86 tool_execution_intercepts: SortedRegistry::new(|entry| entry.priority),
87 llm_sanitize_request_guardrails: SortedRegistry::new(|entry| entry.priority),
88 llm_sanitize_response_guardrails: SortedRegistry::new(|entry| entry.priority),
89 llm_conditional_execution_guardrails: SortedRegistry::new(|entry| entry.priority),
90 llm_request_intercepts: SortedRegistry::new(|entry| entry.priority),
91 llm_execution_intercepts: SortedRegistry::new(|entry| entry.priority),
92 llm_stream_execution_intercepts: SortedRegistry::new(|entry| entry.priority),
93 event_subscribers: HashMap::new(),
94 extensions: HashMap::new(),
95 }
96 }
97
98 /// Store an arbitrary runtime extension under `key`.
99 ///
100 /// Extensions let bindings or integrations attach shared state to the
101 /// process-global runtime without adding new first-class fields.
102 ///
103 /// # Parameters
104 /// - `key`: Stable identifier for the extension slot.
105 /// - `value`: Typed extension value to store.
106 pub fn set_extension<T: Any + Send + Sync>(&mut self, key: impl Into<String>, value: T) {
107 self.extensions.insert(key.into(), Box::new(value));
108 }
109
110 /// Borrow a typed runtime extension by key.
111 ///
112 /// # Parameters
113 /// - `key`: Extension slot name.
114 ///
115 /// # Returns
116 /// `Some(&T)` when an extension exists under `key` with the requested type
117 /// and `None` otherwise.
118 pub fn get_extension<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
119 self.extensions
120 .get(key)
121 .and_then(|value| value.downcast_ref::<T>())
122 }
123
124 /// Mutably borrow a typed runtime extension by key.
125 ///
126 /// # Parameters
127 /// - `key`: Extension slot name.
128 ///
129 /// # Returns
130 /// `Some(&mut T)` when an extension exists under `key` with the requested
131 /// type and `None` otherwise.
132 pub fn get_extension_mut<T: Any + Send + Sync>(&mut self, key: &str) -> Option<&mut T> {
133 self.extensions
134 .get_mut(key)
135 .and_then(|value| value.downcast_mut::<T>())
136 }
137
138 /// Remove a runtime extension by key.
139 ///
140 /// # Parameters
141 /// - `key`: Extension slot name.
142 ///
143 /// # Returns
144 /// `true` when an extension was removed and `false` when no extension was
145 /// stored under `key`.
146 pub fn remove_extension(&mut self, key: &str) -> bool {
147 self.extensions.remove(key).is_some()
148 }
149
150 /// Combine global and scope-local subscribers into one delivery list.
151 ///
152 /// # Parameters
153 /// - `scope_local_subscribers`: Subscribers collected from the active scope
154 /// stack.
155 ///
156 /// # Returns
157 /// A vector containing all global subscribers followed by the provided
158 /// scope-local subscribers.
159 pub fn collect_event_subscribers(
160 &self,
161 scope_local_subscribers: &[EventSubscriberFn],
162 ) -> Vec<EventSubscriberFn> {
163 let mut subscribers =
164 Vec::with_capacity(self.event_subscribers.len() + scope_local_subscribers.len());
165 subscribers.extend(self.event_subscribers.values().cloned());
166 subscribers.extend(scope_local_subscribers.iter().cloned());
167 subscribers
168 }
169
170 /// Deliver an event to every subscriber in order.
171 ///
172 /// # Parameters
173 /// - `event`: Fully constructed lifecycle event to deliver.
174 /// - `subscribers`: Subscribers that should observe the event.
175 pub fn emit_event(event: &Event, subscribers: &[EventSubscriberFn]) {
176 for subscriber in subscribers {
177 subscriber(event);
178 }
179 }
180
181 /// Build a standalone mark event.
182 ///
183 /// # Parameters
184 /// - `params`: A pre-built [`MarkEvent`] to wrap in an [`Event`].
185 ///
186 /// # Returns
187 /// A mark [`Event`] containing the provided [`MarkEvent`].
188 pub fn create_event(&self, params: MarkEvent) -> Event {
189 Event::Mark(params)
190 }
191
192 /// Create a new scope handle.
193 ///
194 /// # Parameters
195 /// - `name`: Human-readable scope name.
196 /// - `parent_uuid`: Optional parent scope UUID.
197 /// - `scope_type`: Semantic category of the scope.
198 /// - `attributes`: Scope attribute bitflags.
199 /// - `data`: Optional application payload stored on the handle.
200 /// - `metadata`: Optional metadata stored on the handle.
201 /// - `timestamp`: Optional handle start time. When omitted, the current
202 /// UTC time is used.
203 ///
204 /// # Returns
205 /// A new [`ScopeHandle`] with a fresh UUID.
206 pub fn create_scope_handle(&self, params: CreateScopeHandleParams<'_>) -> ScopeHandle {
207 ScopeHandle::builder()
208 .name(params.name)
209 .scope_type(params.scope_type)
210 .started_at(params.timestamp.unwrap_or_else(Utc::now))
211 .attributes(params.attributes)
212 .parent_uuid_opt(params.parent_uuid)
213 .data_opt(params.data)
214 .metadata_opt(params.metadata)
215 .build()
216 }
217
218 /// Build a scope-start event from a handle.
219 ///
220 /// # Parameters
221 /// - `handle`: Scope handle to serialize into an event.
222 /// - `data`: Optional semantic input payload exported on the start event.
223 ///
224 /// # Returns
225 /// A scope-start [`Event`] derived from the provided handle.
226 pub fn build_scope_start_event(&self, handle: &ScopeHandle, data: Option<Json>) -> Event {
227 Event::Scope(ScopeEvent::new(
228 BaseEvent::builder()
229 .parent_uuid_opt(handle.parent_uuid)
230 .uuid(handle.uuid)
231 .timestamp(handle.started_at)
232 .name(handle.name.as_str())
233 .data_opt(data)
234 .metadata_opt(handle.metadata.clone())
235 .build(),
236 ScopeCategory::Start,
237 scope_attributes_to_strings(handle.attributes),
238 EventCategory::from(handle.scope_type),
239 None,
240 ))
241 }
242
243 /// Build a scope-end event from a handle.
244 ///
245 /// # Parameters
246 /// - `handle`: Scope handle to serialize into an event.
247 /// - `data`: Optional data payload returned from the scope.
248 ///
249 /// # Returns
250 /// A scope-end [`Event`] derived from the provided handle.
251 pub fn end_scope_handle(&self, handle: &ScopeHandle, data: Option<Json>) -> Event {
252 self.build_scope_end_event(
253 EndScopeHandleParams::builder()
254 .handle(handle)
255 .data_opt(data)
256 .build(),
257 )
258 }
259
260 /// Build a scope-end event from builder parameters.
261 ///
262 /// # Parameters
263 /// - `params`: Scope end-event builder parameters.
264 ///
265 /// # Returns
266 /// A scope-end [`Event`] derived from the provided parameters.
267 pub fn build_scope_end_event(&self, params: EndScopeHandleParams<'_>) -> Event {
268 let handle = params.handle;
269 Event::Scope(ScopeEvent::new(
270 BaseEvent::builder()
271 .parent_uuid_opt(handle.parent_uuid)
272 .uuid(handle.uuid)
273 .timestamp(
274 params
275 .timestamp
276 .unwrap_or_else(|| end_timestamp_after(handle.started_at)),
277 )
278 .name(handle.name.as_str())
279 .data_opt(params.data)
280 .metadata_opt(handle.metadata.clone())
281 .build(),
282 ScopeCategory::End,
283 scope_attributes_to_strings(handle.attributes),
284 EventCategory::from(handle.scope_type),
285 None,
286 ))
287 }
288
289 /// Create a new tool handle.
290 ///
291 /// # Parameters
292 /// - `name`: Tool name recorded on emitted events.
293 /// - `parent_uuid`: Optional parent scope UUID.
294 /// - `attributes`: Tool attribute bitflags.
295 /// - `data`: Optional application payload stored on the handle.
296 /// - `metadata`: Optional metadata stored on the handle.
297 /// - `tool_call_id`: Optional provider-specific correlation identifier.
298 /// - `timestamp`: Optional handle start time. When omitted, the current
299 /// UTC time is used.
300 ///
301 /// # Returns
302 /// A new [`ToolHandle`] with a fresh UUID.
303 pub fn create_tool_handle(&self, params: CreateToolHandleParams<'_>) -> ToolHandle {
304 ToolHandle::builder()
305 .name(params.name)
306 .started_at(params.timestamp.unwrap_or_else(Utc::now))
307 .attributes(params.attributes)
308 .parent_uuid_opt(params.parent_uuid)
309 .data_opt(params.data)
310 .metadata_opt(params.metadata)
311 .tool_call_id_opt(params.tool_call_id)
312 .build()
313 }
314
315 /// Build a tool-start event from a handle.
316 ///
317 /// # Parameters
318 /// - `handle`: Tool handle to serialize into an event.
319 /// - `data`: Optional tool input payload.
320 ///
321 /// # Returns
322 /// A tool-start [`Event`] derived from the provided handle.
323 pub fn build_tool_start_event(&self, handle: &ToolHandle, data: Option<Json>) -> Event {
324 Event::Scope(ScopeEvent::new(
325 BaseEvent::builder()
326 .parent_uuid_opt(handle.parent_uuid)
327 .uuid(handle.uuid)
328 .timestamp(handle.started_at)
329 .name(handle.name.as_str())
330 .data_opt(data)
331 .metadata_opt(handle.metadata.clone())
332 .build(),
333 ScopeCategory::Start,
334 tool_attributes_to_strings(handle.attributes),
335 EventCategory::tool(),
336 Some(
337 CategoryProfile::builder()
338 .tool_call_id_opt(handle.tool_call_id.clone())
339 .build(),
340 ),
341 ))
342 }
343
344 /// Build a tool-end event from a handle and optional overrides.
345 ///
346 /// # Parameters
347 /// - `handle`: Tool handle to serialize into an event.
348 /// - `data`: Optional end-event data payload.
349 /// - `metadata`: Optional metadata payload merged over `handle.metadata`.
350 ///
351 /// # Returns
352 /// A tool-end [`Event`] derived from the provided handle.
353 pub fn end_tool_handle(
354 &self,
355 handle: &ToolHandle,
356 data: Option<Json>,
357 metadata: Option<Json>,
358 ) -> Event {
359 self.build_tool_end_event(
360 EndToolHandleParams::builder()
361 .handle(handle)
362 .data_opt(data)
363 .metadata_opt(metadata)
364 .build(),
365 )
366 }
367
368 /// Build a tool-end event from builder parameters.
369 ///
370 /// The `metadata` payload is merged over the metadata already stored on
371 /// the handle.
372 ///
373 /// # Parameters
374 /// - `params`: Tool end-event builder parameters.
375 ///
376 /// # Returns
377 /// A tool-end [`Event`] derived from the provided parameters.
378 pub fn build_tool_end_event(&self, params: EndToolHandleParams<'_>) -> Event {
379 let handle = params.handle;
380 Event::Scope(ScopeEvent::new(
381 BaseEvent::builder()
382 .parent_uuid_opt(handle.parent_uuid)
383 .uuid(handle.uuid)
384 .timestamp(
385 params
386 .timestamp
387 .unwrap_or_else(|| end_timestamp_after(handle.started_at)),
388 )
389 .name(handle.name.as_str())
390 .data_opt(params.data)
391 .metadata_opt(merge_json(handle.metadata.clone(), params.metadata))
392 .build(),
393 ScopeCategory::End,
394 tool_attributes_to_strings(handle.attributes),
395 EventCategory::tool(),
396 Some(
397 CategoryProfile::builder()
398 .tool_call_id_opt(handle.tool_call_id.clone())
399 .build(),
400 ),
401 ))
402 }
403
404 /// Create a new LLM handle.
405 ///
406 /// # Parameters
407 /// - `name`: Logical provider or model family name.
408 /// - `parent_uuid`: Optional parent scope UUID.
409 /// - `attributes`: LLM attribute bitflags.
410 /// - `data`: Optional application payload stored on the handle.
411 /// - `metadata`: Optional metadata stored on the handle.
412 /// - `model_name`: Optional normalized model name stored on the handle.
413 /// - `timestamp`: Optional handle start time. When omitted, the current
414 /// UTC time is used.
415 ///
416 /// # Returns
417 /// A new [`LlmHandle`] with a fresh UUID.
418 pub fn create_llm_handle(&self, params: CreateLlmHandleParams<'_>) -> LlmHandle {
419 LlmHandle::builder()
420 .name(params.name)
421 .started_at(params.timestamp.unwrap_or_else(Utc::now))
422 .attributes(params.attributes)
423 .parent_uuid_opt(params.parent_uuid)
424 .data_opt(params.data)
425 .metadata_opt(params.metadata)
426 .model_name_opt(params.model_name)
427 .build()
428 }
429
430 /// Build an LLM-start event from a handle.
431 ///
432 /// # Parameters
433 /// - `handle`: LLM handle to serialize into an event.
434 /// - `data`: Sanitized LLM request payload.
435 /// - `annotated_request`: Optional normalized request annotation.
436 ///
437 /// # Returns
438 /// An LLM-start [`Event`] derived from the provided handle.
439 pub fn build_llm_start_event(
440 &self,
441 handle: &LlmHandle,
442 data: Option<Json>,
443 annotated_request: Option<Arc<AnnotatedLlmRequest>>,
444 ) -> Event {
445 Event::Scope(ScopeEvent::new(
446 BaseEvent::builder()
447 .parent_uuid_opt(handle.parent_uuid)
448 .uuid(handle.uuid)
449 .timestamp(handle.started_at)
450 .name(handle.name.as_str())
451 .data_opt(data)
452 .metadata_opt(handle.metadata.clone())
453 .build(),
454 ScopeCategory::Start,
455 llm_attributes_to_strings(handle.attributes),
456 EventCategory::llm(),
457 Some(
458 CategoryProfile::builder()
459 .model_name_opt(handle.model_name.clone())
460 .annotated_request_opt(annotated_request)
461 .build(),
462 ),
463 ))
464 }
465
466 /// Build an LLM-end event from a handle and optional overrides.
467 ///
468 /// # Parameters
469 /// - `handle`: LLM handle to serialize into an event.
470 /// - `data`: Sanitized LLM response payload.
471 /// - `metadata`: Optional metadata payload merged over `handle.metadata`.
472 /// - `annotated_response`: Optional normalized response annotation.
473 ///
474 /// # Returns
475 /// An LLM-end [`Event`] derived from the provided handle.
476 pub fn end_llm_handle(
477 &self,
478 handle: &LlmHandle,
479 data: Option<Json>,
480 metadata: Option<Json>,
481 annotated_response: Option<Arc<AnnotatedLlmResponse>>,
482 ) -> Event {
483 self.build_llm_end_event(
484 EndLlmHandleParams::builder()
485 .handle(handle)
486 .data_opt(data)
487 .metadata_opt(metadata)
488 .annotated_response_opt(annotated_response)
489 .build(),
490 )
491 }
492
493 /// Build an LLM-end event from builder parameters.
494 ///
495 /// The `metadata` payload is merged over the metadata already stored on
496 /// the handle.
497 ///
498 /// # Parameters
499 /// - `params`: LLM end-event builder parameters.
500 ///
501 /// # Returns
502 /// An LLM-end [`Event`] derived from the provided parameters.
503 pub fn build_llm_end_event(&self, params: EndLlmHandleParams<'_>) -> Event {
504 let handle = params.handle;
505 Event::Scope(ScopeEvent::new(
506 BaseEvent::builder()
507 .parent_uuid_opt(handle.parent_uuid)
508 .uuid(handle.uuid)
509 .timestamp(
510 params
511 .timestamp
512 .unwrap_or_else(|| end_timestamp_after(handle.started_at)),
513 )
514 .name(handle.name.as_str())
515 .data_opt(params.data)
516 .metadata_opt(merge_json(handle.metadata.clone(), params.metadata))
517 .build(),
518 ScopeCategory::End,
519 llm_attributes_to_strings(handle.attributes),
520 EventCategory::llm(),
521 Some(
522 CategoryProfile::builder()
523 .model_name_opt(handle.model_name.clone())
524 .annotated_response_opt(params.annotated_response)
525 .build(),
526 ),
527 ))
528 }
529
530 /// Run tool request sanitizers across global and scope-local registries.
531 ///
532 /// # Parameters
533 /// - `name`: Tool name associated with the request.
534 /// - `args`: Raw tool arguments to sanitize for observability.
535 /// - `scope_locals`: Scope-local sanitizer registries collected from the
536 /// active scope stack.
537 ///
538 /// # Returns
539 /// The sanitized JSON payload after every matching guardrail has run.
540 pub fn tool_sanitize_request_chain(
541 &self,
542 name: &str,
543 args: Json,
544 scope_locals: &[&SortedRegistry<GuardrailEntry<ToolSanitizeFn>>],
545 ) -> Json {
546 let entries = merge_guardrail_entries(&self.tool_sanitize_request_guardrails, scope_locals);
547 let mut value = args;
548 for entry in entries {
549 value = (entry.guardrail)(name, value);
550 }
551 value
552 }
553
554 /// Run tool response sanitizers across global and scope-local registries.
555 ///
556 /// # Parameters
557 /// - `name`: Tool name associated with the response.
558 /// - `result`: Raw tool result to sanitize for observability.
559 /// - `scope_locals`: Scope-local sanitizer registries collected from the
560 /// active scope stack.
561 ///
562 /// # Returns
563 /// The sanitized JSON payload after every matching guardrail has run.
564 pub fn tool_sanitize_response_chain(
565 &self,
566 name: &str,
567 result: Json,
568 scope_locals: &[&SortedRegistry<GuardrailEntry<ToolSanitizeFn>>],
569 ) -> Json {
570 let entries =
571 merge_guardrail_entries(&self.tool_sanitize_response_guardrails, scope_locals);
572 let mut value = result;
573 for entry in entries {
574 value = (entry.guardrail)(name, value);
575 }
576 value
577 }
578
579 /// Evaluate tool conditional-execution guardrails in priority order.
580 ///
581 /// # Parameters
582 /// - `name`: Tool name associated with the request.
583 /// - `args`: Tool arguments to validate.
584 /// - `scope_locals`: Scope-local conditional guardrail registries collected
585 /// from the active scope stack.
586 ///
587 /// # Returns
588 /// A [`Result`] containing `Ok(None)` when execution is allowed or
589 /// `Ok(Some(reason))` when a guardrail rejects the call.
590 ///
591 /// # Errors
592 /// Propagates any error returned by a guardrail callback.
593 pub fn tool_conditional_execution_chain(
594 &self,
595 name: &str,
596 args: &Json,
597 scope_locals: &[&SortedRegistry<GuardrailEntry<ToolConditionalFn>>],
598 ) -> crate::error::Result<Option<String>> {
599 let entries =
600 merge_guardrail_entries(&self.tool_conditional_execution_guardrails, scope_locals);
601 for entry in entries {
602 if let Some(error) = (entry.guardrail)(name, args)? {
603 return Ok(Some(error));
604 }
605 }
606 Ok(None)
607 }
608
609 /// Run tool request intercepts in priority order.
610 ///
611 /// # Parameters
612 /// - `name`: Tool name associated with the request.
613 /// - `args`: Tool arguments to pass through the intercept chain.
614 /// - `scope_locals`: Scope-local request intercept registries collected
615 /// from the active scope stack.
616 ///
617 /// # Returns
618 /// A [`Result`] containing the final JSON argument payload.
619 ///
620 /// # Errors
621 /// Propagates any error returned by an intercept callback.
622 ///
623 /// # Notes
624 /// If an intercept entry has `break_chain` enabled, later intercepts are
625 /// skipped after that entry runs.
626 pub fn tool_request_intercepts_chain(
627 &self,
628 name: &str,
629 args: Json,
630 scope_locals: &[&SortedRegistry<Intercept<ToolInterceptFn>>],
631 ) -> crate::error::Result<Json> {
632 let entries = merge_intercept_entries(&self.tool_request_intercepts, scope_locals);
633 let mut value = args;
634 for entry in entries {
635 value = (entry.callable)(name, value)?;
636 if entry.break_chain {
637 break;
638 }
639 }
640 Ok(value)
641 }
642
643 /// Build the composed tool execution continuation chain.
644 ///
645 /// # Parameters
646 /// - `name`: Tool name passed into each execution intercept.
647 /// - `default_fn`: Base tool callback that should run after all intercepts.
648 /// - `scope_locals`: Scope-local execution intercept registries collected
649 /// from the active scope stack.
650 ///
651 /// # Returns
652 /// A composed [`ToolExecutionNextFn`] that wraps `default_fn` in every
653 /// matching execution intercept.
654 pub fn tool_build_execution_chain(
655 &self,
656 name: &str,
657 default_fn: ToolExecutionNextFn,
658 scope_locals: &[&SortedRegistry<ExecutionIntercept<ToolExecutionFn>>],
659 ) -> ToolExecutionNextFn {
660 let matching =
661 merge_execution_intercept_callables(&self.tool_execution_intercepts, scope_locals);
662 let mut next = default_fn;
663 let name = name.to_string();
664 for (callable, _) in matching.into_iter().rev() {
665 let current_next = next.clone();
666 let current_name = name.clone();
667 next = Arc::new(move |args| callable(¤t_name, args, current_next.clone()));
668 }
669 next
670 }
671
672 /// Run LLM request sanitizers across global and scope-local registries.
673 ///
674 /// # Parameters
675 /// - `request`: Raw LLM request to sanitize for observability.
676 /// - `scope_locals`: Scope-local sanitizer registries collected from the
677 /// active scope stack.
678 ///
679 /// # Returns
680 /// The sanitized [`LlmRequest`] after every matching guardrail has run.
681 pub fn llm_sanitize_request_chain(
682 &self,
683 request: LlmRequest,
684 scope_locals: &[&SortedRegistry<GuardrailEntry<LlmSanitizeRequestFn>>],
685 ) -> LlmRequest {
686 let entries = merge_guardrail_entries(&self.llm_sanitize_request_guardrails, scope_locals);
687 let mut value = request;
688 for entry in entries {
689 value = (entry.guardrail)(value);
690 }
691 value
692 }
693
694 /// Run LLM response sanitizers across global and scope-local registries.
695 ///
696 /// # Parameters
697 /// - `response`: Raw response payload to sanitize for observability.
698 /// - `scope_locals`: Scope-local sanitizer registries collected from the
699 /// active scope stack.
700 ///
701 /// # Returns
702 /// The sanitized response payload after every matching guardrail has run.
703 pub fn llm_sanitize_response_chain(
704 &self,
705 response: Json,
706 scope_locals: &[&SortedRegistry<GuardrailEntry<LlmSanitizeResponseFn>>],
707 ) -> Json {
708 let entries = merge_guardrail_entries(&self.llm_sanitize_response_guardrails, scope_locals);
709 let mut value = response;
710 for entry in entries {
711 value = (entry.guardrail)(value);
712 }
713 value
714 }
715
716 /// Evaluate LLM conditional-execution guardrails in priority order.
717 ///
718 /// # Parameters
719 /// - `request`: LLM request to validate.
720 /// - `scope_locals`: Scope-local conditional guardrail registries collected
721 /// from the active scope stack.
722 ///
723 /// # Returns
724 /// A [`Result`] containing `Ok(None)` when execution is allowed or
725 /// `Ok(Some(reason))` when a guardrail rejects the call.
726 ///
727 /// # Errors
728 /// Propagates any error returned by a guardrail callback.
729 pub fn llm_conditional_execution_chain(
730 &self,
731 request: &LlmRequest,
732 scope_locals: &[&SortedRegistry<GuardrailEntry<LlmConditionalFn>>],
733 ) -> crate::error::Result<Option<String>> {
734 let entries =
735 merge_guardrail_entries(&self.llm_conditional_execution_guardrails, scope_locals);
736 for entry in entries {
737 if let Some(error) = (entry.guardrail)(request)? {
738 return Ok(Some(error));
739 }
740 }
741 Ok(None)
742 }
743
744 /// Run LLM request intercepts in priority order.
745 ///
746 /// # Parameters
747 /// - `name`: Logical provider or model family name.
748 /// - `request`: LLM request to pass through the intercept chain.
749 /// - `annotated`: Optional normalized request annotation to carry through
750 /// the chain.
751 /// - `scope_locals`: Scope-local request intercept registries collected
752 /// from the active scope stack.
753 ///
754 /// # Returns
755 /// A [`Result`] containing the final request and annotation pair.
756 ///
757 /// # Errors
758 /// Propagates any error returned by an intercept callback.
759 ///
760 /// # Notes
761 /// If an intercept entry has `break_chain` enabled, later intercepts are
762 /// skipped after that entry runs.
763 pub fn llm_request_intercepts_chain(
764 &self,
765 name: &str,
766 request: LlmRequest,
767 annotated: Option<AnnotatedLlmRequest>,
768 scope_locals: &[&SortedRegistry<Intercept<LlmRequestInterceptFn>>],
769 ) -> crate::error::Result<(LlmRequest, Option<AnnotatedLlmRequest>)> {
770 let entries = merge_intercept_entries(&self.llm_request_intercepts, scope_locals);
771 let mut request_value = request;
772 let mut annotated_value = annotated;
773 for entry in entries {
774 let (new_request, new_annotated) =
775 (entry.callable)(name, request_value, annotated_value)?;
776 request_value = new_request;
777 annotated_value = new_annotated;
778 if entry.break_chain {
779 break;
780 }
781 }
782 Ok((request_value, annotated_value))
783 }
784
785 /// Build the composed non-streaming LLM execution continuation chain.
786 ///
787 /// # Parameters
788 /// - `name`: Logical provider or model family name passed into each
789 /// execution intercept.
790 /// - `default_fn`: Base provider callback that should run after all
791 /// intercepts.
792 /// - `scope_locals`: Scope-local execution intercept registries collected
793 /// from the active scope stack.
794 ///
795 /// # Returns
796 /// A composed [`LlmExecutionNextFn`] that wraps `default_fn` in every
797 /// matching execution intercept.
798 pub fn llm_build_execution_chain(
799 &self,
800 name: &str,
801 default_fn: LlmExecutionNextFn,
802 scope_locals: &[&SortedRegistry<ExecutionIntercept<LlmExecutionFn>>],
803 ) -> LlmExecutionNextFn {
804 let matching =
805 merge_execution_intercept_callables(&self.llm_execution_intercepts, scope_locals);
806 let mut next = default_fn;
807 let name = name.to_string();
808 for (callable, _) in matching.into_iter().rev() {
809 let current_next = next.clone();
810 let current_name = name.clone();
811 next = Arc::new(move |request| callable(¤t_name, request, current_next.clone()));
812 }
813 next
814 }
815
816 /// Build the composed streaming LLM execution continuation chain.
817 ///
818 /// # Parameters
819 /// - `name`: Logical provider or model family name passed into each
820 /// execution intercept.
821 /// - `default_fn`: Base stream-producing callback that should run after all
822 /// intercepts.
823 /// - `scope_locals`: Scope-local execution intercept registries collected
824 /// from the active scope stack.
825 ///
826 /// # Returns
827 /// A composed [`LlmStreamExecutionNextFn`] that wraps `default_fn` in every
828 /// matching execution intercept.
829 pub fn llm_stream_build_execution_chain(
830 &self,
831 name: &str,
832 default_fn: LlmStreamExecutionNextFn,
833 scope_locals: LlmStreamExecutionRegistryRefs<'_>,
834 ) -> LlmStreamExecutionNextFn {
835 let matching = merge_execution_intercept_callables(
836 &self.llm_stream_execution_intercepts,
837 scope_locals,
838 );
839 let mut next = default_fn;
840 let name = name.to_string();
841 for (callable, _) in matching.into_iter().rev() {
842 let current_next = next.clone();
843 let current_name = name.clone();
844 next = Arc::new(move |request| callable(¤t_name, request, current_next.clone()));
845 }
846 next
847 }
848}
849
850fn end_timestamp_after(started_at: chrono::DateTime<Utc>) -> chrono::DateTime<Utc> {
851 let now = Utc::now();
852 if now > started_at {
853 now
854 } else {
855 started_at + Duration::microseconds(1)
856 }
857}
858
859impl Default for NemoFlowContextState {
860 fn default() -> Self {
861 Self::new()
862 }
863}