1use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use bitflags::bitflags;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use serde_json::json;
11use typed_builder::TypedBuilder;
12use uuid::Uuid;
13
14use crate::api::runtime::NemoFlowContextState;
15use crate::api::runtime::current_scope_stack;
16use crate::api::runtime::global_context;
17use crate::api::runtime::{
18 LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, LlmStreamExecutionNextFn,
19};
20use crate::api::scope::event;
21use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
22use crate::api::shared::{
23 ensure_runtime_owner, resolve_parent_uuid, run_request_intercepts_with_codec,
24 snapshot_event_subscribers,
25};
26use crate::codec::request::AnnotatedLlmRequest;
27use crate::codec::response::AnnotatedLlmResponse;
28use crate::codec::traits::{LlmCodec, LlmResponseCodec};
29use crate::error::{FlowError, Result};
30use crate::json::Json;
31use crate::stream::LlmStreamWrapper;
32
33bitflags! {
34 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
36 pub struct LlmAttributes: u32 {
37 const STATEFUL = 0b01;
39 const STREAMING = 0b10;
41 }
42}
43
/// Identity and descriptive data for one LLM call.
///
/// A handle is created through the global context state and is then passed to the start and
/// end event builders for that call.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmHandle {
    /// Unique id of the call; defaults to a freshly generated UUIDv7.
    #[builder(default = Uuid::now_v7())]
    pub uuid: Uuid,
    /// Start time of the call; defaults to `Utc::now()` evaluated when the builder runs.
    #[builder(default = Utc::now())]
    pub started_at: DateTime<Utc>,
    /// Name of the call.
    #[builder(setter(into))]
    pub name: String,
    /// Optional caller-supplied JSON data.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional caller-supplied JSON metadata.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Attribute flags; empty by default.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// UUID of the parent scope, if any.
    #[builder(default)]
    pub parent_uuid: Option<Uuid>,
    /// Optional model name.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
}
73
/// An LLM request as seen by guardrails, intercepts and the execution callback.
///
/// The start event records it serialized to JSON, after the sanitize-request guardrails run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Header entries as a JSON object.
    pub headers: serde_json::Map<String, Json>,
    /// Request body as arbitrary JSON.
    pub content: Json,
}
82
83#[derive(Debug, Clone, TypedBuilder)]
85#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
86pub struct CreateLlmHandleParams<'a> {
87 pub name: &'a str,
89 #[builder(default)]
91 pub parent_uuid: Option<uuid::Uuid>,
92 #[builder(default = LlmAttributes::empty())]
94 pub attributes: LlmAttributes,
95 #[builder(default)]
97 pub data: Option<Json>,
98 #[builder(default)]
100 pub metadata: Option<Json>,
101 #[builder(default, setter(into))]
103 pub model_name: Option<String>,
104 #[builder(default)]
107 pub timestamp: Option<DateTime<Utc>>,
108}
109
/// Inputs for building an LLM end event (consumed by `build_llm_end_event` on the context
/// state).
#[derive(Clone, TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndLlmHandleParams<'a> {
    /// Handle of the call being finished.
    pub handle: &'a LlmHandle,
    /// Data recorded on the end event. `llm_call_end` sets this to the sanitized response
    /// when that is non-null.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata recorded on the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional annotated (codec-decoded) form of the response.
    #[builder(default)]
    pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
    /// Optional explicit end timestamp.
    /// NOTE(review): presumably the current time is used when `None` — confirm in
    /// `build_llm_end_event`.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
131
/// Parameters for [`llm_call`], which creates an LLM handle and emits its start event.
#[derive(TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallParams<'a> {
    /// Name of the call, forwarded to handle creation.
    pub name: &'a str,
    /// Request recorded on the start event after the sanitize-request guardrails run.
    pub request: &'a LlmRequest,
    /// Optional parent scope. It is resolved to the handle's parent UUID through
    /// `resolve_parent_uuid`.
    #[builder(default)]
    pub parent: Option<&'a ScopeHandle>,
    /// Attribute flags for the handle; empty by default.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON data, forwarded to handle creation.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata, forwarded to handle creation.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name, forwarded to handle creation.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional annotated form of the request, attached to the start event.
    #[builder(default)]
    pub annotated_request: Option<Arc<AnnotatedLlmRequest>>,
    /// Optional timestamp, forwarded to handle creation.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
164
/// Parameters for [`llm_call_execute`], the fully managed non-streaming LLM call.
#[derive(TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallExecuteParams {
    /// Name of the call. It is used for the handle, for the intercept chains and for the
    /// rejection mark event.
    #[builder(setter(into))]
    pub name: String,
    /// Request checked by guardrails and then passed through the request intercepts.
    pub request: LlmRequest,
    /// The underlying LLM execution function. It is wrapped to emit the start event once
    /// before running, then handed to the execution-intercept chain.
    pub func: LlmExecutionNextFn,
    /// Optional parent scope, used for the handle and the rejection mark event.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Attribute flags for the handle; empty by default.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON data, forwarded to handle creation and to `llm_call_end`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata attached to the handle and to the emitted events.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name, forwarded to handle creation.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional request codec, passed to `run_request_intercepts_with_codec` to produce the
    /// annotated request.
    #[builder(default)]
    pub codec: Option<Arc<dyn LlmCodec>>,
    /// Optional response codec used to decode the response into an annotated response.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
}
199
/// Parameters for [`llm_stream_call_execute`], the fully managed streaming LLM call.
#[derive(TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmStreamCallExecuteParams {
    /// Name of the call. It is used for the handle, for the intercept chains and for the
    /// rejection mark event.
    #[builder(setter(into))]
    pub name: String,
    /// Request checked by guardrails and then passed through the request intercepts.
    pub request: LlmRequest,
    /// The underlying streaming execution function. It is wrapped to emit the start event
    /// once before running, then handed to the stream execution-intercept chain.
    pub func: LlmStreamExecutionNextFn,
    /// Passed to `LlmStreamWrapper` together with the raw stream.
    /// NOTE(review): presumably accumulates stream chunks — confirm in `LlmStreamWrapper`.
    pub collector: LlmCollectorFn,
    /// Passed to `LlmStreamWrapper` together with the raw stream.
    /// NOTE(review): presumably builds the final response at stream end — confirm.
    pub finalizer: LlmFinalizerFn,
    /// Optional parent scope, used for the handle and the rejection mark event.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Attribute flags for the handle; empty by default.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON data, forwarded to handle creation and to `LlmStreamWrapper`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata attached to the handle and to the emitted events.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name, forwarded to handle creation.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional request codec, passed to `run_request_intercepts_with_codec` to produce the
    /// annotated request.
    #[builder(default)]
    pub codec: Option<Arc<dyn LlmCodec>>,
    /// Optional response codec, passed to `LlmStreamWrapper`.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
}
238
/// Parameters for [`llm_call_end`], which emits the end event of a manually managed call.
#[derive(TypedBuilder)]
// `Option` fields get a setter taking the inner value plus an `<field>_opt` setter taking the
// whole `Option`.
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallEndParams<'a> {
    /// Handle of the call being finished.
    pub handle: &'a LlmHandle,
    /// Raw response. The sanitize-response guardrails run on it before it is recorded.
    pub response: Json,
    /// Fallback event data, used only when the sanitized response is JSON null.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional metadata recorded on the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional pre-computed annotated response. When present it is used as-is and
    /// `response_codec` is ignored.
    #[builder(default)]
    pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
    /// Optional codec. When `annotated_response` is `None`, it decodes the end event's data
    /// into an annotated response.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
    /// Optional explicit end timestamp, forwarded to the end event builder.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
266
267fn create_llm_handle(params: CreateLlmHandleParams<'_>) -> Result<LlmHandle> {
268 ensure_runtime_owner()?;
269 let context = global_context();
270 let state = context
271 .read()
272 .map_err(|error| FlowError::Internal(error.to_string()))?;
273 Ok(state.create_llm_handle(params))
274}
275
276fn emit_llm_start(
277 handle: &LlmHandle,
278 request: &LlmRequest,
279 annotated_request: Option<Arc<AnnotatedLlmRequest>>,
280) -> Result<()> {
281 ensure_runtime_owner()?;
282 let (event, subscribers) = {
283 let scope_stack = current_scope_stack();
284 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
285 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
286 ®istries.llm_sanitize_request_guardrails
287 });
288 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
289 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
290 let context = global_context();
291 let state = context
292 .read()
293 .map_err(|error| FlowError::Internal(error.to_string()))?;
294
295 let sanitized_request = state.llm_sanitize_request_chain(request.clone(), &scope_locals);
296 let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null);
297 let event = state.build_llm_start_event(handle, Some(input), annotated_request);
298 (event, subscribers)
299 };
300 NemoFlowContextState::emit_event(&event, &subscribers);
301 Ok(())
302}
303
304fn emit_llm_start_once(
305 start_emitted: &Arc<AtomicBool>,
306 handle: &LlmHandle,
307 request: &LlmRequest,
308 annotated_request: Option<Arc<AnnotatedLlmRequest>>,
309) -> Result<()> {
310 if start_emitted.swap(true, Ordering::SeqCst) {
311 return Ok(());
312 }
313 emit_llm_start(handle, request, annotated_request)
314}
315
316pub fn llm_call(params: LlmCallParams<'_>) -> Result<LlmHandle> {
347 let handle_params = CreateLlmHandleParams::builder()
348 .name(params.name)
349 .parent_uuid_opt(resolve_parent_uuid(params.parent))
350 .attributes(params.attributes)
351 .data_opt(params.data)
352 .metadata_opt(params.metadata)
353 .model_name_opt(params.model_name)
354 .timestamp_opt(params.timestamp)
355 .build();
356 let handle = create_llm_handle(handle_params)?;
357 emit_llm_start(&handle, params.request, params.annotated_request)?;
358 Ok(handle)
359}
360
361pub fn llm_call_end(params: LlmCallEndParams<'_>) -> Result<()> {
393 let LlmCallEndParams {
394 handle,
395 response,
396 data,
397 metadata,
398 annotated_response,
399 response_codec,
400 timestamp,
401 } = params;
402 ensure_runtime_owner()?;
403 let mut decode_error = None;
404 let (event, subscribers) = {
405 let scope_stack = current_scope_stack();
406 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
407 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
408 ®istries.llm_sanitize_response_guardrails
409 });
410 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
411 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
412 let context = global_context();
413 let state = context
414 .read()
415 .map_err(|error| FlowError::Internal(error.to_string()))?;
416
417 let sanitized_response = state.llm_sanitize_response_chain(response, &scope_locals);
418 let data = if sanitized_response.is_null() {
419 data
420 } else {
421 Some(sanitized_response)
422 };
423 let annotated_response = match annotated_response {
424 Some(annotated_response) => Some(annotated_response),
425 None => match (response_codec.as_ref(), data.as_ref()) {
426 (Some(codec), Some(response)) => match codec.decode_response(response) {
427 Ok(decoded) => Some(Arc::new(decoded)),
428 Err(error) => {
429 decode_error = Some(error);
430 None
431 }
432 },
433 _ => None,
434 },
435 };
436 let event = state.build_llm_end_event(
437 EndLlmHandleParams::builder()
438 .handle(handle)
439 .data_opt(data)
440 .metadata_opt(metadata)
441 .annotated_response_opt(annotated_response)
442 .timestamp_opt(timestamp)
443 .build(),
444 );
445 (event, subscribers)
446 };
447 NemoFlowContextState::emit_event(&event, &subscribers);
448 if let Some(error) = decode_error {
449 Err(error)
450 } else {
451 Ok(())
452 }
453}
454
455fn emit_llm_end_without_output(handle: &LlmHandle, metadata: Option<Json>) -> Result<()> {
456 ensure_runtime_owner()?;
457 let (event, subscribers) = {
458 let scope_stack = current_scope_stack();
459 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
460 let scope_subscribers = scope_guard.collect_scope_local_subscribers();
461 let subscribers = snapshot_event_subscribers(scope_subscribers)?;
462 let context = global_context();
463 let state = context
464 .read()
465 .map_err(|error| FlowError::Internal(error.to_string()))?;
466 let event = state.end_llm_handle(handle, handle.data.clone(), metadata, None);
467 (event, subscribers)
468 };
469 NemoFlowContextState::emit_event(&event, &subscribers);
470 Ok(())
471}
472
473pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
507 let LlmCallExecuteParams {
508 name,
509 request,
510 func,
511 parent,
512 attributes,
513 data,
514 metadata,
515 model_name,
516 codec,
517 response_codec,
518 } = params;
519 ensure_runtime_owner()?;
520 {
521 let scope_stack = current_scope_stack();
522 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
523 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
524 ®istries.llm_conditional_execution_guardrails
525 });
526 let context = global_context();
527 let state = context
528 .read()
529 .map_err(|error| FlowError::Internal(error.to_string()))?;
530 if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
531 drop(state);
532 drop(scope_guard);
533 let mut rejection_data = json!({});
534 if let Some(object) = rejection_data.as_object_mut() {
535 object.insert("rejected".into(), json!(true));
536 object.insert("rejection_reason".into(), json!(&error));
537 }
538 let _ = event(
539 EmitMarkEventParams::builder()
540 .name(&name)
541 .parent_opt(parent.as_ref())
542 .data(rejection_data)
543 .metadata_opt(metadata.clone())
544 .build(),
545 );
546 return Err(FlowError::GuardrailRejected(error));
547 }
548 }
549
550 let (intercepted_request, annotated_request) =
551 run_request_intercepts_with_codec(&name, request, codec)?;
552
553 let handle = create_llm_handle(
554 CreateLlmHandleParams::builder()
555 .name(name.as_str())
556 .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
557 .attributes(attributes)
558 .data_opt(data.clone())
559 .metadata_opt(metadata.clone())
560 .model_name_opt(model_name)
561 .build(),
562 )?;
563 let start_emitted = Arc::new(AtomicBool::new(false));
564 let fallback_request = intercepted_request.clone();
565 let execution_handle = handle.clone();
566 let execution_annotated_request = annotated_request.clone();
567 let execution_start_emitted = start_emitted.clone();
568 let instrumented_func: LlmExecutionNextFn = Arc::new(move |request| {
569 let next = func.clone();
570 let handle = execution_handle.clone();
571 let annotated_request = execution_annotated_request.clone();
572 let start_emitted = execution_start_emitted.clone();
573 Box::pin(async move {
574 emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
575 next(request).await
576 })
577 });
578
579 let execution = {
580 let scope_stack = current_scope_stack();
581 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
582 let scope_locals = scope_guard
583 .collect_scope_local_registries(|registries| ®istries.llm_execution_intercepts);
584 let context = global_context();
585 let state = context
586 .read()
587 .map_err(|error| FlowError::Internal(error.to_string()))?;
588 state.llm_build_execution_chain(&name, instrumented_func, &scope_locals)
589 };
590
591 match execution(intercepted_request).await {
592 Ok(response) => {
593 emit_llm_start_once(
594 &start_emitted,
595 &handle,
596 &fallback_request,
597 annotated_request.clone(),
598 )?;
599 let annotated_response = response_codec
600 .as_ref()
601 .and_then(|codec| codec.decode_response(&response).ok())
602 .map(Arc::new);
603 llm_call_end(
604 LlmCallEndParams::builder()
605 .handle(&handle)
606 .response(response.clone())
607 .data_opt(data)
608 .metadata_opt(metadata)
609 .annotated_response_opt(annotated_response)
610 .build(),
611 )?;
612 Ok(response)
613 }
614 Err(error) => {
615 let _ = emit_llm_start_once(
616 &start_emitted,
617 &handle,
618 &fallback_request,
619 annotated_request,
620 );
621 let _ = emit_llm_end_without_output(&handle, metadata);
622 Err(error)
623 }
624 }
625}
626
627pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Result<LlmJsonStream> {
662 let LlmStreamCallExecuteParams {
663 name,
664 request,
665 func,
666 collector,
667 finalizer,
668 parent,
669 attributes,
670 data,
671 metadata,
672 model_name,
673 codec,
674 response_codec,
675 } = params;
676 ensure_runtime_owner()?;
677 {
678 let scope_stack = current_scope_stack();
679 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
680 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
681 ®istries.llm_conditional_execution_guardrails
682 });
683 let context = global_context();
684 let state = context
685 .read()
686 .map_err(|error| FlowError::Internal(error.to_string()))?;
687 if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
688 drop(state);
689 drop(scope_guard);
690 let mut rejection_data = json!({});
691 if let Some(object) = rejection_data.as_object_mut() {
692 object.insert("rejected".into(), json!(true));
693 object.insert("rejection_reason".into(), json!(&error));
694 }
695 let _ = event(
696 EmitMarkEventParams::builder()
697 .name(&name)
698 .parent_opt(parent.as_ref())
699 .data(rejection_data)
700 .metadata_opt(metadata.clone())
701 .build(),
702 );
703 return Err(FlowError::GuardrailRejected(error));
704 }
705 }
706
707 let (intercepted_request, annotated_request) =
708 run_request_intercepts_with_codec(&name, request, codec)?;
709
710 let handle = create_llm_handle(
711 CreateLlmHandleParams::builder()
712 .name(name.as_str())
713 .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
714 .attributes(attributes)
715 .data_opt(data.clone())
716 .metadata_opt(metadata.clone())
717 .model_name_opt(model_name)
718 .build(),
719 )?;
720 let start_emitted = Arc::new(AtomicBool::new(false));
721 let fallback_request = intercepted_request.clone();
722 let execution_handle = handle.clone();
723 let execution_annotated_request = annotated_request.clone();
724 let execution_start_emitted = start_emitted.clone();
725 let instrumented_func: LlmStreamExecutionNextFn = Arc::new(move |request| {
726 let next = func.clone();
727 let handle = execution_handle.clone();
728 let annotated_request = execution_annotated_request.clone();
729 let start_emitted = execution_start_emitted.clone();
730 Box::pin(async move {
731 emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
732 next(request).await
733 })
734 });
735
736 let execution = {
737 let scope_stack = current_scope_stack();
738 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
739 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
740 ®istries.llm_stream_execution_intercepts
741 });
742 let context = global_context();
743 let state = context
744 .read()
745 .map_err(|error| FlowError::Internal(error.to_string()))?;
746 state.llm_stream_build_execution_chain(&name, instrumented_func, &scope_locals)
747 };
748
749 match execution(intercepted_request).await {
750 Ok(raw_stream) => {
751 emit_llm_start_once(
752 &start_emitted,
753 &handle,
754 &fallback_request,
755 annotated_request.clone(),
756 )?;
757 let wrapper = LlmStreamWrapper::new(
758 raw_stream,
759 handle,
760 collector,
761 finalizer,
762 data,
763 metadata,
764 response_codec,
765 );
766 Ok(Box::pin(wrapper) as LlmJsonStream)
767 }
768 Err(error) => {
769 let _ = emit_llm_start_once(
770 &start_emitted,
771 &handle,
772 &fallback_request,
773 annotated_request,
774 );
775 let _ = emit_llm_end_without_output(&handle, metadata);
776 Err(error)
777 }
778 }
779}
780
781pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequest> {
801 ensure_runtime_owner()?;
802 let scope_stack = current_scope_stack();
803 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
804 let scope_locals =
805 scope_guard.collect_scope_local_registries(|registries| ®istries.llm_request_intercepts);
806 let context = global_context();
807 let state = context
808 .read()
809 .map_err(|error| FlowError::Internal(error.to_string()))?;
810 let (request, _) = state.llm_request_intercepts_chain(name, request, None, &scope_locals)?;
811 Ok(request)
812}
813
814pub fn llm_conditional_execution(request: &LlmRequest) -> Result<()> {
833 ensure_runtime_owner()?;
834 let scope_stack = current_scope_stack();
835 let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
836 let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
837 ®istries.llm_conditional_execution_guardrails
838 });
839 let context = global_context();
840 let state = context
841 .read()
842 .map_err(|error| FlowError::Internal(error.to_string()))?;
843 if let Some(error) = state.llm_conditional_execution_chain(request, &scope_locals)? {
844 return Err(FlowError::GuardrailRejected(error));
845 }
846 Ok(())
847}