1use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use bitflags::bitflags;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use typed_builder::TypedBuilder;
12use uuid::Uuid;
13
14use crate::api::runtime::NemoFlowContextState;
15use crate::api::runtime::current_scope_stack;
16use crate::api::runtime::global_context;
17use crate::api::runtime::{
18 LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, LlmStreamExecutionNextFn,
19};
20use crate::api::scope::event;
21use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
22use crate::api::shared::{
23 ensure_runtime_owner, resolve_parent_uuid, run_request_intercepts_with_codec,
24 snapshot_event_subscribers,
25};
26use crate::codec::request::AnnotatedLlmRequest;
27use crate::codec::response::AnnotatedLlmResponse;
28use crate::codec::traits::{LlmCodec, LlmResponseCodec};
29use crate::error::{FlowError, Result};
30use crate::json::Json;
31use crate::stream::LlmStreamWrapper;
32
33bitflags! {
34 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
36 pub struct LlmAttributes: u32 {
37 const STATEFUL = 0b01;
39 const STREAMING = 0b10;
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
46#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
47pub struct LlmHandle {
48 #[builder(default = Uuid::now_v7())]
50 pub uuid: Uuid,
51 #[builder(default = Utc::now())]
53 pub started_at: DateTime<Utc>,
54 #[builder(setter(into))]
56 pub name: String,
57 #[builder(default)]
59 pub data: Option<Json>,
60 #[builder(default)]
62 pub metadata: Option<Json>,
63 #[builder(default = LlmAttributes::empty())]
65 pub attributes: LlmAttributes,
66 #[builder(default)]
68 pub parent_uuid: Option<Uuid>,
69 #[builder(default, setter(into))]
71 pub model_name: Option<String>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct LlmRequest {
77 pub headers: serde_json::Map<String, Json>,
79 pub content: Json,
81}
82
83#[derive(Debug, Clone, TypedBuilder)]
85#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
86pub struct CreateLlmHandleParams<'a> {
87 pub name: &'a str,
89 #[builder(default)]
91 pub parent_uuid: Option<uuid::Uuid>,
92 #[builder(default = LlmAttributes::empty())]
94 pub attributes: LlmAttributes,
95 #[builder(default)]
97 pub data: Option<Json>,
98 #[builder(default)]
100 pub metadata: Option<Json>,
101 #[builder(default, setter(into))]
103 pub model_name: Option<String>,
104 #[builder(default)]
107 pub timestamp: Option<DateTime<Utc>>,
108}
109
110#[derive(Clone, TypedBuilder)]
112#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
113pub struct EndLlmHandleParams<'a> {
114 pub handle: &'a LlmHandle,
116 #[builder(default)]
118 pub data: Option<Json>,
119 #[builder(default)]
121 pub metadata: Option<Json>,
122 #[builder(default)]
124 pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
125 #[builder(default)]
129 pub timestamp: Option<DateTime<Utc>>,
130}
131
132#[derive(TypedBuilder)]
134#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
135pub struct LlmCallParams<'a> {
136 pub name: &'a str,
138 pub request: &'a LlmRequest,
140 #[builder(default)]
142 pub parent: Option<&'a ScopeHandle>,
143 #[builder(default = LlmAttributes::empty())]
145 pub attributes: LlmAttributes,
146 #[builder(default)]
148 pub data: Option<Json>,
149 #[builder(default)]
151 pub metadata: Option<Json>,
152 #[builder(default, setter(into))]
154 pub model_name: Option<String>,
155 #[builder(default)]
157 pub annotated_request: Option<Arc<AnnotatedLlmRequest>>,
158 #[builder(default)]
161 pub timestamp: Option<DateTime<Utc>>,
162}
163
164#[derive(TypedBuilder)]
166#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
167pub struct LlmCallExecuteParams {
168 #[builder(setter(into))]
170 pub name: String,
171 pub request: LlmRequest,
173 pub func: LlmExecutionNextFn,
175 #[builder(default)]
177 pub parent: Option<ScopeHandle>,
178 #[builder(default = LlmAttributes::empty())]
180 pub attributes: LlmAttributes,
181 #[builder(default)]
183 pub data: Option<Json>,
184 #[builder(default)]
186 pub metadata: Option<Json>,
187 #[builder(default, setter(into))]
189 pub model_name: Option<String>,
190 #[builder(default)]
192 pub codec: Option<Arc<dyn LlmCodec>>,
193 #[builder(default)]
195 pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
196}
197
198#[derive(TypedBuilder)]
200#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
201pub struct LlmStreamCallExecuteParams {
202 #[builder(setter(into))]
204 pub name: String,
205 pub request: LlmRequest,
207 pub func: LlmStreamExecutionNextFn,
209 pub collector: LlmCollectorFn,
211 pub finalizer: LlmFinalizerFn,
213 #[builder(default)]
215 pub parent: Option<ScopeHandle>,
216 #[builder(default = LlmAttributes::empty())]
218 pub attributes: LlmAttributes,
219 #[builder(default)]
221 pub data: Option<Json>,
222 #[builder(default)]
224 pub metadata: Option<Json>,
225 #[builder(default, setter(into))]
227 pub model_name: Option<String>,
228 #[builder(default)]
230 pub codec: Option<Arc<dyn LlmCodec>>,
231 #[builder(default)]
233 pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
234}
235
236#[derive(TypedBuilder)]
238#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
239pub struct LlmCallEndParams<'a> {
240 pub handle: &'a LlmHandle,
242 pub response: Json,
244 #[builder(default)]
246 pub data: Option<Json>,
247 #[builder(default)]
249 pub metadata: Option<Json>,
250 #[builder(default)]
252 pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
253 #[builder(default)]
257 pub timestamp: Option<DateTime<Utc>>,
258}
259
260fn create_llm_handle(params: CreateLlmHandleParams<'_>) -> Result<LlmHandle> {
261 ensure_runtime_owner()?;
262 let context = global_context();
263 let state = context
264 .read()
265 .map_err(|error| FlowError::Internal(error.to_string()))?;
266 Ok(state.create_llm_handle(params))
267}
268
269fn emit_llm_start(
270 handle: &LlmHandle,
271 request: &LlmRequest,
272 annotated_request: Option<Arc<AnnotatedLlmRequest>>,
273) -> Result<()> {
274 ensure_runtime_owner()?;
275 let (event, subscribers) = {
276 let scope_stack = current_scope_stack();
277 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
278 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
279 ®istries.llm_sanitize_request_guardrails
280 });
281 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
282 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
283 let context = global_context();
284 let state = context
285 .read()
286 .map_err(|error| FlowError::Internal(error.to_string()))?;
287
288 let sanitized_request = state.llm_sanitize_request_chain(request.clone(), &scope_locals);
289 let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null);
290 let event = state.build_llm_start_event(handle, Some(input), annotated_request);
291 (event, subscribers)
292 };
293 NemoFlowContextState::emit_event(&event, &subscribers);
294 Ok(())
295}
296
297fn emit_llm_start_once(
298 start_emitted: &Arc<AtomicBool>,
299 handle: &LlmHandle,
300 request: &LlmRequest,
301 annotated_request: Option<Arc<AnnotatedLlmRequest>>,
302) -> Result<()> {
303 if start_emitted.swap(true, Ordering::SeqCst) {
304 return Ok(());
305 }
306 emit_llm_start(handle, request, annotated_request)
307}
308
309pub fn llm_call(params: LlmCallParams<'_>) -> Result<LlmHandle> {
340 let handle_params = CreateLlmHandleParams::builder()
341 .name(params.name)
342 .parent_uuid_opt(resolve_parent_uuid(params.parent))
343 .attributes(params.attributes)
344 .data_opt(params.data)
345 .metadata_opt(params.metadata)
346 .model_name_opt(params.model_name)
347 .timestamp_opt(params.timestamp)
348 .build();
349 let handle = create_llm_handle(handle_params)?;
350 emit_llm_start(&handle, params.request, params.annotated_request)?;
351 Ok(handle)
352}
353
354pub fn llm_call_end(params: LlmCallEndParams<'_>) -> Result<()> {
383 ensure_runtime_owner()?;
384 let (event, subscribers) = {
385 let scope_stack = current_scope_stack();
386 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
387 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
388 ®istries.llm_sanitize_response_guardrails
389 });
390 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
391 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
392 let context = global_context();
393 let state = context
394 .read()
395 .map_err(|error| FlowError::Internal(error.to_string()))?;
396
397 let sanitized_response = state.llm_sanitize_response_chain(params.response, &scope_locals);
398 let data = if sanitized_response.is_null() {
399 params.data
400 } else {
401 Some(sanitized_response)
402 };
403 let event = state.build_llm_end_event(
404 EndLlmHandleParams::builder()
405 .handle(params.handle)
406 .data_opt(data)
407 .metadata_opt(params.metadata)
408 .annotated_response_opt(params.annotated_response)
409 .timestamp_opt(params.timestamp)
410 .build(),
411 );
412 (event, subscribers)
413 };
414 NemoFlowContextState::emit_event(&event, &subscribers);
415 Ok(())
416}
417
418fn emit_llm_end_without_output(handle: &LlmHandle, metadata: Option<Json>) -> Result<()> {
419 ensure_runtime_owner()?;
420 let (event, subscribers) = {
421 let scope_stack = current_scope_stack();
422 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
423 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
424 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
425 let context = global_context();
426 let state = context
427 .read()
428 .map_err(|error| FlowError::Internal(error.to_string()))?;
429 let event = state.end_llm_handle(handle, handle.data.clone(), metadata, None);
430 (event, subscribers)
431 };
432 NemoFlowContextState::emit_event(&event, &subscribers);
433 Ok(())
434}
435
436pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
470 let LlmCallExecuteParams {
471 name,
472 request,
473 func,
474 parent,
475 attributes,
476 data,
477 metadata,
478 model_name,
479 codec,
480 response_codec,
481 } = params;
482 ensure_runtime_owner()?;
483 {
484 let scope_stack = current_scope_stack();
485 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
486 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
487 ®istries.llm_conditional_execution_guardrails
488 });
489 let context = global_context();
490 let state = context
491 .read()
492 .map_err(|error| FlowError::Internal(error.to_string()))?;
493 if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
494 drop(state);
495 drop(scope_guard);
496 let mut rejection_data = json!({});
497 if let Some(object) = rejection_data.as_object_mut() {
498 object.insert("rejected".into(), json!(true));
499 object.insert("rejection_reason".into(), json!(&error));
500 }
501 let _ = event(
502 EmitMarkEventParams::builder()
503 .name(&name)
504 .parent_opt(parent.as_ref())
505 .data(rejection_data)
506 .metadata_opt(metadata.clone())
507 .build(),
508 );
509 return Err(FlowError::GuardrailRejected(error));
510 }
511 }
512
513 let (intercepted_request, annotated_request) =
514 run_request_intercepts_with_codec(&name, request, codec)?;
515
516 let handle = create_llm_handle(
517 CreateLlmHandleParams::builder()
518 .name(name.as_str())
519 .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
520 .attributes(attributes)
521 .data_opt(data.clone())
522 .metadata_opt(metadata.clone())
523 .model_name_opt(model_name)
524 .build(),
525 )?;
526 let start_emitted = Arc::new(AtomicBool::new(false));
527 let fallback_request = intercepted_request.clone();
528 let execution_handle = handle.clone();
529 let execution_annotated_request = annotated_request.clone();
530 let execution_start_emitted = start_emitted.clone();
531 let instrumented_func: LlmExecutionNextFn = Arc::new(move |request| {
532 let next = func.clone();
533 let handle = execution_handle.clone();
534 let annotated_request = execution_annotated_request.clone();
535 let start_emitted = execution_start_emitted.clone();
536 Box::pin(async move {
537 emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
538 next(request).await
539 })
540 });
541
542 let execution = {
543 let scope_stack = current_scope_stack();
544 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
545 let scope_locals = scope_guard
546 .collect_scope_local_registries(|registries| ®istries.llm_execution_intercepts);
547 let context = global_context();
548 let state = context
549 .read()
550 .map_err(|error| FlowError::Internal(error.to_string()))?;
551 state.llm_build_execution_chain(&name, instrumented_func, &scope_locals)
552 };
553
554 match execution(intercepted_request).await {
555 Ok(response) => {
556 emit_llm_start_once(
557 &start_emitted,
558 &handle,
559 &fallback_request,
560 annotated_request.clone(),
561 )?;
562 let annotated_response = response_codec
563 .as_ref()
564 .and_then(|codec| codec.decode_response(&response).ok())
565 .map(Arc::new);
566 llm_call_end(
567 LlmCallEndParams::builder()
568 .handle(&handle)
569 .response(response.clone())
570 .data_opt(data)
571 .metadata_opt(metadata)
572 .annotated_response_opt(annotated_response)
573 .build(),
574 )?;
575 Ok(response)
576 }
577 Err(error) => {
578 let _ = emit_llm_start_once(
579 &start_emitted,
580 &handle,
581 &fallback_request,
582 annotated_request,
583 );
584 let _ = emit_llm_end_without_output(&handle, metadata);
585 Err(error)
586 }
587 }
588}
589
590pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Result<LlmJsonStream> {
625 let LlmStreamCallExecuteParams {
626 name,
627 request,
628 func,
629 collector,
630 finalizer,
631 parent,
632 attributes,
633 data,
634 metadata,
635 model_name,
636 codec,
637 response_codec,
638 } = params;
639 ensure_runtime_owner()?;
640 {
641 let scope_stack = current_scope_stack();
642 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
643 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
644 ®istries.llm_conditional_execution_guardrails
645 });
646 let context = global_context();
647 let state = context
648 .read()
649 .map_err(|error| FlowError::Internal(error.to_string()))?;
650 if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
651 drop(state);
652 drop(scope_guard);
653 let mut rejection_data = json!({});
654 if let Some(object) = rejection_data.as_object_mut() {
655 object.insert("rejected".into(), json!(true));
656 object.insert("rejection_reason".into(), json!(&error));
657 }
658 let _ = event(
659 EmitMarkEventParams::builder()
660 .name(&name)
661 .parent_opt(parent.as_ref())
662 .data(rejection_data)
663 .metadata_opt(metadata.clone())
664 .build(),
665 );
666 return Err(FlowError::GuardrailRejected(error));
667 }
668 }
669
670 let (intercepted_request, annotated_request) =
671 run_request_intercepts_with_codec(&name, request, codec)?;
672
673 let handle = create_llm_handle(
674 CreateLlmHandleParams::builder()
675 .name(name.as_str())
676 .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
677 .attributes(attributes)
678 .data_opt(data.clone())
679 .metadata_opt(metadata.clone())
680 .model_name_opt(model_name)
681 .build(),
682 )?;
683 let start_emitted = Arc::new(AtomicBool::new(false));
684 let fallback_request = intercepted_request.clone();
685 let execution_handle = handle.clone();
686 let execution_annotated_request = annotated_request.clone();
687 let execution_start_emitted = start_emitted.clone();
688 let instrumented_func: LlmStreamExecutionNextFn = Arc::new(move |request| {
689 let next = func.clone();
690 let handle = execution_handle.clone();
691 let annotated_request = execution_annotated_request.clone();
692 let start_emitted = execution_start_emitted.clone();
693 Box::pin(async move {
694 emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
695 next(request).await
696 })
697 });
698
699 let execution = {
700 let scope_stack = current_scope_stack();
701 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
702 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
703 ®istries.llm_stream_execution_intercepts
704 });
705 let context = global_context();
706 let state = context
707 .read()
708 .map_err(|error| FlowError::Internal(error.to_string()))?;
709 state.llm_stream_build_execution_chain(&name, instrumented_func, &scope_locals)
710 };
711
712 match execution(intercepted_request).await {
713 Ok(raw_stream) => {
714 emit_llm_start_once(
715 &start_emitted,
716 &handle,
717 &fallback_request,
718 annotated_request.clone(),
719 )?;
720 let wrapper = LlmStreamWrapper::new(
721 raw_stream,
722 handle,
723 collector,
724 finalizer,
725 data,
726 metadata,
727 response_codec,
728 );
729 Ok(Box::pin(wrapper) as LlmJsonStream)
730 }
731 Err(error) => {
732 let _ = emit_llm_start_once(
733 &start_emitted,
734 &handle,
735 &fallback_request,
736 annotated_request,
737 );
738 let _ = emit_llm_end_without_output(&handle, metadata);
739 Err(error)
740 }
741 }
742}
743
744pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequest> {
764 ensure_runtime_owner()?;
765 let scope_stack = current_scope_stack();
766 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
767 let scope_locals =
768 scope_guard.collect_scope_local_registries(|registries| ®istries.llm_request_intercepts);
769 let context = global_context();
770 let state = context
771 .read()
772 .map_err(|error| FlowError::Internal(error.to_string()))?;
773 let (request, _) = state.llm_request_intercepts_chain(name, request, None, &scope_locals)?;
774 Ok(request)
775}
776
777pub fn llm_conditional_execution(request: &LlmRequest) -> Result<()> {
796 ensure_runtime_owner()?;
797 let scope_stack = current_scope_stack();
798 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
799 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
800 ®istries.llm_conditional_execution_guardrails
801 });
802 let context = global_context();
803 let state = context
804 .read()
805 .map_err(|error| FlowError::Internal(error.to_string()))?;
806 if let Some(error) = state.llm_conditional_execution_chain(request, &scope_locals)? {
807 return Err(FlowError::GuardrailRejected(error));
808 }
809 Ok(())
810}