use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use bitflags::bitflags;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::json;
use typed_builder::TypedBuilder;
use uuid::Uuid;
use crate::api::runtime::NemoFlowContextState;
use crate::api::runtime::current_scope_stack;
use crate::api::runtime::global_context;
use crate::api::runtime::{
LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, LlmStreamExecutionNextFn,
};
use crate::api::scope::event;
use crate::api::scope::{EmitMarkEventParams, ScopeHandle};
use crate::api::shared::{
ensure_runtime_owner, resolve_parent_uuid, run_request_intercepts_with_codec,
snapshot_event_subscribers,
};
use crate::codec::request::AnnotatedLlmRequest;
use crate::codec::response::AnnotatedLlmResponse;
use crate::codec::traits::{LlmCodec, LlmResponseCodec};
use crate::error::{FlowError, Result};
use crate::json::Json;
use crate::stream::LlmStreamWrapper;
bitflags! {
    /// Bit set of optional characteristics attached to an LLM call.
    ///
    /// The set is stored in a `u32`. The builders in this module default to
    /// `LlmAttributes::empty()` when no attributes are supplied.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub struct LlmAttributes: u32 {
        /// Bit 0.
        /// NOTE(review): presumably marks a call that carries state across
        /// invocations — confirm against the consumers of this flag.
        const STATEFUL = 1 << 0;
        /// Bit 1.
        /// NOTE(review): presumably marks a streaming call — confirm against
        /// the consumers of this flag.
        const STREAMING = 1 << 1;
    }
}
/// Describes one LLM call: its identifier, start timestamp, name, attribute
/// flags and optional payloads.
///
/// Instances are assembled through the `TypedBuilder`-generated builder. For
/// `Option` fields the plain setter takes the inner value, while the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmHandle {
    /// Identifier of the call; the builder falls back to `Uuid::now_v7()`.
    #[builder(default = Uuid::now_v7())]
    pub uuid: Uuid,
    /// Start timestamp; the builder falls back to `Utc::now()`.
    #[builder(default = Utc::now())]
    pub started_at: DateTime<Utc>,
    /// Name of the call; the setter accepts anything convertible into `String`.
    #[builder(setter(into))]
    pub name: String,
    /// Optional JSON payload attached to the call (`None` when unset).
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata attached to the call (`None` when unset).
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Attribute flags; the empty set when unset.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Identifier of the parent, if any (`None` when unset).
    #[builder(default)]
    pub parent_uuid: Option<Uuid>,
    /// Optional model name (`None` when unset).
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
}
/// Request payload handed to the LLM call functions of this module: a header
/// map plus a JSON body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Header entries keyed by name; each value is arbitrary JSON.
    pub headers: serde_json::Map<String, Json>,
    /// JSON body of the request.
    pub content: Json,
}
#[derive(Debug, Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct CreateLlmHandleParams<'a> {
pub name: &'a str,
#[builder(default)]
pub parent_uuid: Option<uuid::Uuid>,
#[builder(default = LlmAttributes::empty())]
pub attributes: LlmAttributes,
#[builder(default)]
pub data: Option<Json>,
#[builder(default)]
pub metadata: Option<Json>,
#[builder(default, setter(into))]
pub model_name: Option<String>,
#[builder(default)]
pub timestamp: Option<DateTime<Utc>>,
}
/// Inputs used to build the end event of an LLM call; `llm_call_end` passes
/// them to the runtime state's `build_llm_end_event`.
///
/// Built through the `TypedBuilder`-generated builder. For `Option` fields the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(Clone, TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct EndLlmHandleParams<'a> {
    /// Handle of the call being ended (required).
    pub handle: &'a LlmHandle,
    /// Optional JSON payload for the end event.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata for the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional decoded (annotated) form of the response.
    #[builder(default)]
    pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
    /// Optional explicit timestamp.
    /// NOTE(review): presumably the end time of the call — confirm in
    /// `build_llm_end_event`.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
/// Borrowed inputs for [`llm_call`], which creates a handle and emits the
/// start event without executing anything.
///
/// Built through the `TypedBuilder`-generated builder. For `Option` fields the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallParams<'a> {
    /// Name of the call (required).
    pub name: &'a str,
    /// Request of the call (required); a sanitized copy becomes the input of
    /// the start event.
    pub request: &'a LlmRequest,
    /// Optional parent scope; it is resolved to a parent UUID through
    /// `resolve_parent_uuid`.
    #[builder(default)]
    pub parent: Option<&'a ScopeHandle>,
    /// Attribute flags; the empty set when unset.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON payload forwarded to the handle.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata forwarded to the handle.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name forwarded to the handle.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional decoded (annotated) form of the request, forwarded to the
    /// start event.
    #[builder(default)]
    pub annotated_request: Option<Arc<AnnotatedLlmRequest>>,
    /// Optional explicit timestamp forwarded to handle creation.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
/// Owned inputs for [`llm_call_execute`], the non-streaming execution entry
/// point.
///
/// Built through the `TypedBuilder`-generated builder. For `Option` fields the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallExecuteParams {
    /// Name of the call; the setter accepts anything convertible into `String`.
    #[builder(setter(into))]
    pub name: String,
    /// Request to run through guardrails and intercepts (required).
    pub request: LlmRequest,
    /// Callable that the execution chain is built around (required).
    pub func: LlmExecutionNextFn,
    /// Optional parent scope of the call.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Attribute flags; the empty set when unset.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON payload forwarded to the handle and to the end event.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata forwarded to the handle and to emitted events.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name forwarded to the handle.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional codec handed to `run_request_intercepts_with_codec`.
    #[builder(default)]
    pub codec: Option<Arc<dyn LlmCodec>>,
    /// Optional codec used to decode the response for the end event; a decode
    /// failure is ignored by `llm_call_execute`.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
}
/// Owned inputs for [`llm_stream_call_execute`], the streaming execution entry
/// point.
///
/// Built through the `TypedBuilder`-generated builder. For `Option` fields the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmStreamCallExecuteParams {
    /// Name of the call; the setter accepts anything convertible into `String`.
    #[builder(setter(into))]
    pub name: String,
    /// Request to run through guardrails and intercepts (required).
    pub request: LlmRequest,
    /// Stream-producing callable that the execution chain is built around
    /// (required).
    pub func: LlmStreamExecutionNextFn,
    /// Collector handed to `LlmStreamWrapper::new` (required).
    pub collector: LlmCollectorFn,
    /// Finalizer handed to `LlmStreamWrapper::new` (required).
    pub finalizer: LlmFinalizerFn,
    /// Optional parent scope of the call.
    #[builder(default)]
    pub parent: Option<ScopeHandle>,
    /// Attribute flags; the empty set when unset.
    #[builder(default = LlmAttributes::empty())]
    pub attributes: LlmAttributes,
    /// Optional JSON payload forwarded to the handle and to the stream wrapper.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata forwarded to the handle, emitted events and the
    /// stream wrapper.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Optional model name forwarded to the handle.
    #[builder(default, setter(into))]
    pub model_name: Option<String>,
    /// Optional codec handed to `run_request_intercepts_with_codec`.
    #[builder(default)]
    pub codec: Option<Arc<dyn LlmCodec>>,
    /// Optional response codec handed to `LlmStreamWrapper::new`.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
}
/// Inputs for [`llm_call_end`], which emits the end event of an LLM call.
///
/// Built through the `TypedBuilder`-generated builder. For `Option` fields the
/// `<field>_opt` setter takes the `Option` itself.
#[derive(TypedBuilder)]
#[builder(field_defaults(setter(strip_option(ignore_invalid, fallback_suffix = "_opt"))))]
pub struct LlmCallEndParams<'a> {
    /// Handle of the call being ended (required).
    pub handle: &'a LlmHandle,
    /// Raw response (required); it is passed through the response sanitize
    /// chain before being attached to the end event.
    pub response: Json,
    /// Fallback event payload, used only when the sanitized response is JSON
    /// `null`.
    #[builder(default)]
    pub data: Option<Json>,
    /// Optional JSON metadata for the end event.
    #[builder(default)]
    pub metadata: Option<Json>,
    /// Pre-decoded response; when set, `response_codec` is not consulted.
    #[builder(default)]
    pub annotated_response: Option<Arc<AnnotatedLlmResponse>>,
    /// Codec used to decode the event payload when `annotated_response` is
    /// not supplied.
    #[builder(default)]
    pub response_codec: Option<Arc<dyn LlmResponseCodec>>,
    /// Optional explicit timestamp forwarded to the end event.
    #[builder(default)]
    pub timestamp: Option<DateTime<Utc>>,
}
/// Creates a new [`LlmHandle`] from `params` through the global runtime
/// context.
///
/// # Errors
/// Propagates the error of `ensure_runtime_owner`, and returns
/// [`FlowError::Internal`] when the global context lock cannot be read.
fn create_llm_handle(params: CreateLlmHandleParams<'_>) -> Result<LlmHandle> {
    ensure_runtime_owner()?;
    let runtime = global_context();
    // The read guard lives only for the duration of the closure.
    let created = runtime
        .read()
        .map(|runtime_state| runtime_state.create_llm_handle(params))
        .map_err(|poisoned| FlowError::Internal(poisoned.to_string()))?;
    Ok(created)
}
/// Builds the start event for `handle` and dispatches it to the current
/// subscribers.
///
/// The request is cloned and passed through the request sanitize chain (global
/// plus scope-local guardrails); the sanitized request, serialized to JSON,
/// becomes the event input. If serialization fails the input is JSON `null`.
///
/// # Errors
/// Propagates the errors of `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
fn emit_llm_start(
    handle: &LlmHandle,
    request: &LlmRequest,
    annotated_request: Option<Arc<AnnotatedLlmRequest>>,
) -> Result<()> {
    ensure_runtime_owner()?;
    // The event is built inside this block; both lock guards are released at
    // its end, before the event is dispatched.
    let (event, subscribers) = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.llm_sanitize_request_guardrails
        });
        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        let sanitized_request = state.llm_sanitize_request_chain(request.clone(), &scope_locals);
        // Best effort: a request that cannot be serialized yields a `null` input
        // instead of failing the emission.
        let input = serde_json::to_value(&sanitized_request).unwrap_or(Json::Null);
        let event = state.build_llm_start_event(handle, Some(input), annotated_request);
        (event, subscribers)
    };
    NemoFlowContextState::emit_event(&event, &subscribers);
    Ok(())
}
/// Emits the start event for `handle` unless `start_emitted` was already set.
///
/// The flag is set atomically before the emission is attempted, so a failed
/// emission is not retried by later calls.
///
/// # Errors
/// Propagates the error of `emit_llm_start` when this call performs the
/// emission.
fn emit_llm_start_once(
    start_emitted: &Arc<AtomicBool>,
    handle: &LlmHandle,
    request: &LlmRequest,
    annotated_request: Option<Arc<AnnotatedLlmRequest>>,
) -> Result<()> {
    let already_emitted = start_emitted.swap(true, Ordering::SeqCst);
    if already_emitted {
        Ok(())
    } else {
        emit_llm_start(handle, request, annotated_request)
    }
}
pub fn llm_call(params: LlmCallParams<'_>) -> Result<LlmHandle> {
let handle_params = CreateLlmHandleParams::builder()
.name(params.name)
.parent_uuid_opt(resolve_parent_uuid(params.parent))
.attributes(params.attributes)
.data_opt(params.data)
.metadata_opt(params.metadata)
.model_name_opt(params.model_name)
.timestamp_opt(params.timestamp)
.build();
let handle = create_llm_handle(handle_params)?;
emit_llm_start(&handle, params.request, params.annotated_request)?;
Ok(handle)
}
/// Emits the end event for an LLM call.
///
/// The response is passed through the response sanitize chain (global plus
/// scope-local guardrails). The sanitized response becomes the event payload,
/// unless it is JSON `null`, in which case the caller-supplied `data` is used.
/// When no `annotated_response` is supplied but a `response_codec` is, the
/// payload is decoded with that codec.
///
/// # Errors
/// Propagates the errors of `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read. A codec decode failure is also
/// returned, but only after the end event (without an annotated response) has
/// been dispatched.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
pub fn llm_call_end(params: LlmCallEndParams<'_>) -> Result<()> {
    let LlmCallEndParams {
        handle,
        response,
        data,
        metadata,
        annotated_response,
        response_codec,
        timestamp,
    } = params;
    ensure_runtime_owner()?;
    // Remembered so that the event is still emitted before the error surfaces.
    let mut decode_error = None;
    // Both lock guards are released at the end of this block, before the event
    // is dispatched.
    let (event, subscribers) = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.llm_sanitize_response_guardrails
        });
        let scope_subscribers = scope_guard.collect_scope_local_subscribers();
        let subscribers = snapshot_event_subscribers(scope_subscribers)?;
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        let sanitized_response = state.llm_sanitize_response_chain(response, &scope_locals);
        // A `null` sanitized response falls back to the caller-supplied data.
        let data = if sanitized_response.is_null() {
            data
        } else {
            Some(sanitized_response)
        };
        // Decode only when the caller did not pre-decode and both a codec and
        // a payload are available.
        let annotated_response = annotated_response.or_else(|| {
            let codec = response_codec.as_ref()?;
            let payload = data.as_ref()?;
            match codec.decode_response(payload) {
                Ok(decoded) => Some(Arc::new(decoded)),
                Err(error) => {
                    decode_error = Some(error);
                    None
                }
            }
        });
        let event = state.build_llm_end_event(
            EndLlmHandleParams::builder()
                .handle(handle)
                .data_opt(data)
                .metadata_opt(metadata)
                .annotated_response_opt(annotated_response)
                .timestamp_opt(timestamp)
                .build(),
        );
        (event, subscribers)
    };
    NemoFlowContextState::emit_event(&event, &subscribers);
    decode_error.map_or(Ok(()), Err)
}
/// Ends `handle` without a response and dispatches the resulting event.
///
/// The handle's own `data` is cloned and passed as the payload, and no
/// annotated response is attached. The execute functions call this on their
/// failure path.
///
/// # Errors
/// Propagates the errors of `ensure_runtime_owner` and
/// `snapshot_event_subscribers`, and returns [`FlowError::Internal`] when the
/// global context lock cannot be read.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
fn emit_llm_end_without_output(handle: &LlmHandle, metadata: Option<Json>) -> Result<()> {
    ensure_runtime_owner()?;
    // Both lock guards are released at the end of this block, before the event
    // is dispatched.
    let (end_event, listeners) = {
        let stack = current_scope_stack();
        let scopes = stack.read().expect("scope stack lock poisoned");
        let listeners = snapshot_event_subscribers(scopes.collect_scope_local_subscribers())?;
        let runtime = global_context();
        let runtime_state = runtime
            .read()
            .map_err(|poisoned| FlowError::Internal(poisoned.to_string()))?;
        let carried_data = handle.data.clone();
        let end_event = runtime_state.end_llm_handle(handle, carried_data, metadata, None);
        (end_event, listeners)
    };
    NemoFlowContextState::emit_event(&end_event, &listeners);
    Ok(())
}
/// Runs a non-streaming LLM call through guardrails, request intercepts and
/// the execution chain, emitting start and end events around it.
///
/// Steps, in order:
/// 1. The conditional-execution guardrails are evaluated against the original
///    request. On rejection a mark event carrying `rejected` and
///    `rejection_reason` is emitted (best effort), no handle is created, and
///    [`FlowError::GuardrailRejected`] is returned.
/// 2. The request intercepts run, using the optional `codec`.
/// 3. A handle is created and the execution chain is built around `func`.
///    The start event is emitted exactly once: when the chain reaches `func`
///    (with the request the chain passed down), or otherwise right after the
///    chain completes (with the intercepted request).
/// 4. On success the end event is emitted with the response, annotated through
///    `response_codec` when decoding succeeds, and the response is returned.
///    On failure an end event without output is emitted (best effort) and the
///    chain's error is returned.
///
/// # Errors
/// Returns [`FlowError::GuardrailRejected`] on rejection,
/// [`FlowError::Internal`] when the global context lock cannot be read, and
/// otherwise propagates errors from the runtime-owner check, the guardrail
/// chain, the intercepts, handle creation, event emission on the success path,
/// and the execution chain itself.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result<Json> {
    let LlmCallExecuteParams {
        name,
        request,
        func,
        parent,
        attributes,
        data,
        metadata,
        model_name,
        codec,
        response_codec,
    } = params;
    ensure_runtime_owner()?;
    {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.llm_conditional_execution_guardrails
        });
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
            // Release both guards before emitting the mark event.
            drop(state);
            drop(scope_guard);
            let rejection_data = json!({
                "rejected": true,
                "rejection_reason": &error,
            });
            // Best effort: a failure to emit the mark event must not mask the
            // rejection itself.
            let _ = event(
                EmitMarkEventParams::builder()
                    .name(&name)
                    .parent_opt(parent.as_ref())
                    .data(rejection_data)
                    .metadata_opt(metadata)
                    .build(),
            );
            return Err(FlowError::GuardrailRejected(error));
        }
    }
    let (intercepted_request, annotated_request) =
        run_request_intercepts_with_codec(&name, request, codec)?;
    let handle = create_llm_handle(
        CreateLlmHandleParams::builder()
            .name(name.as_str())
            .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
            .attributes(attributes)
            .data_opt(data.clone())
            .metadata_opt(metadata.clone())
            .model_name_opt(model_name)
            .build(),
    )?;
    // Shared between the instrumented callable and the fallback emissions
    // below so that the start event is emitted only once.
    let start_emitted = Arc::new(AtomicBool::new(false));
    let fallback_request = intercepted_request.clone();
    let execution_handle = handle.clone();
    let execution_annotated_request = annotated_request.clone();
    let execution_start_emitted = start_emitted.clone();
    // Wraps `func` so the start event is emitted with the request that
    // actually reaches it.
    let instrumented_func: LlmExecutionNextFn = Arc::new(move |request| {
        let next = func.clone();
        let handle = execution_handle.clone();
        let annotated_request = execution_annotated_request.clone();
        let start_emitted = execution_start_emitted.clone();
        Box::pin(async move {
            emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
            next(request).await
        })
    });
    // The chain is built inside this block so both lock guards are released
    // before the chain is awaited.
    let execution = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard
            .collect_scope_local_registries(|registries| &registries.llm_execution_intercepts);
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.llm_build_execution_chain(&name, instrumented_func, &scope_locals)
    };
    match execution(intercepted_request).await {
        Ok(response) => {
            // Covers chains that produced a response without reaching `func`.
            emit_llm_start_once(
                &start_emitted,
                &handle,
                &fallback_request,
                annotated_request,
            )?;
            // Annotation is best effort: a decode failure leaves the end event
            // without an annotated response.
            let annotated_response = response_codec
                .as_ref()
                .and_then(|codec| codec.decode_response(&response).ok())
                .map(Arc::new);
            llm_call_end(
                LlmCallEndParams::builder()
                    .handle(&handle)
                    .response(response.clone())
                    .data_opt(data)
                    .metadata_opt(metadata)
                    .annotated_response_opt(annotated_response)
                    .build(),
            )?;
            Ok(response)
        }
        Err(error) => {
            // Best effort: event emission must not replace the chain's error.
            let _ = emit_llm_start_once(
                &start_emitted,
                &handle,
                &fallback_request,
                annotated_request,
            );
            let _ = emit_llm_end_without_output(&handle, metadata);
            Err(error)
        }
    }
}
/// Runs a streaming LLM call through guardrails, request intercepts and the
/// streaming execution chain, returning the wrapped stream.
///
/// Steps, in order:
/// 1. The conditional-execution guardrails are evaluated against the original
///    request. On rejection a mark event carrying `rejected` and
///    `rejection_reason` is emitted (best effort), no handle is created, and
///    [`FlowError::GuardrailRejected`] is returned.
/// 2. The request intercepts run, using the optional `codec`.
/// 3. A handle is created and the streaming execution chain is built around
///    `func`. The start event is emitted exactly once: when the chain reaches
///    `func` (with the request the chain passed down), or otherwise right
///    after the chain completes (with the intercepted request).
/// 4. On success the raw stream is wrapped in an [`LlmStreamWrapper`] that
///    receives the handle, `collector`, `finalizer`, `data`, `metadata` and
///    `response_codec`. NOTE(review): presumably the wrapper emits the end
///    event once the stream finishes — confirm in `LlmStreamWrapper`.
///    On failure an end event without output is emitted (best effort) and the
///    chain's error is returned.
///
/// # Errors
/// Returns [`FlowError::GuardrailRejected`] on rejection,
/// [`FlowError::Internal`] when the global context lock cannot be read, and
/// otherwise propagates errors from the runtime-owner check, the guardrail
/// chain, the intercepts, handle creation, start-event emission on the success
/// path, and the execution chain itself.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Result<LlmJsonStream> {
    let LlmStreamCallExecuteParams {
        name,
        request,
        func,
        collector,
        finalizer,
        parent,
        attributes,
        data,
        metadata,
        model_name,
        codec,
        response_codec,
    } = params;
    ensure_runtime_owner()?;
    {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.llm_conditional_execution_guardrails
        });
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        if let Some(error) = state.llm_conditional_execution_chain(&request, &scope_locals)? {
            // Release both guards before emitting the mark event.
            drop(state);
            drop(scope_guard);
            let rejection_data = json!({
                "rejected": true,
                "rejection_reason": &error,
            });
            // Best effort: a failure to emit the mark event must not mask the
            // rejection itself.
            let _ = event(
                EmitMarkEventParams::builder()
                    .name(&name)
                    .parent_opt(parent.as_ref())
                    .data(rejection_data)
                    .metadata_opt(metadata)
                    .build(),
            );
            return Err(FlowError::GuardrailRejected(error));
        }
    }
    let (intercepted_request, annotated_request) =
        run_request_intercepts_with_codec(&name, request, codec)?;
    let handle = create_llm_handle(
        CreateLlmHandleParams::builder()
            .name(name.as_str())
            .parent_uuid_opt(resolve_parent_uuid(parent.as_ref()))
            .attributes(attributes)
            .data_opt(data.clone())
            .metadata_opt(metadata.clone())
            .model_name_opt(model_name)
            .build(),
    )?;
    // Shared between the instrumented callable and the fallback emissions
    // below so that the start event is emitted only once.
    let start_emitted = Arc::new(AtomicBool::new(false));
    let fallback_request = intercepted_request.clone();
    let execution_handle = handle.clone();
    let execution_annotated_request = annotated_request.clone();
    let execution_start_emitted = start_emitted.clone();
    // Wraps `func` so the start event is emitted with the request that
    // actually reaches it.
    let instrumented_func: LlmStreamExecutionNextFn = Arc::new(move |request| {
        let next = func.clone();
        let handle = execution_handle.clone();
        let annotated_request = execution_annotated_request.clone();
        let start_emitted = execution_start_emitted.clone();
        Box::pin(async move {
            emit_llm_start_once(&start_emitted, &handle, &request, annotated_request)?;
            next(request).await
        })
    });
    // The chain is built inside this block so both lock guards are released
    // before the chain is awaited.
    let execution = {
        let scope_stack = current_scope_stack();
        let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
        let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
            &registries.llm_stream_execution_intercepts
        });
        let context = global_context();
        let state = context
            .read()
            .map_err(|error| FlowError::Internal(error.to_string()))?;
        state.llm_stream_build_execution_chain(&name, instrumented_func, &scope_locals)
    };
    match execution(intercepted_request).await {
        Ok(raw_stream) => {
            // Covers chains that produced a stream without reaching `func`.
            emit_llm_start_once(
                &start_emitted,
                &handle,
                &fallback_request,
                annotated_request,
            )?;
            let wrapper = LlmStreamWrapper::new(
                raw_stream,
                handle,
                collector,
                finalizer,
                data,
                metadata,
                response_codec,
            );
            Ok(Box::pin(wrapper) as LlmJsonStream)
        }
        Err(error) => {
            // Best effort: event emission must not replace the chain's error.
            let _ = emit_llm_start_once(
                &start_emitted,
                &handle,
                &fallback_request,
                annotated_request,
            );
            let _ = emit_llm_end_without_output(&handle, metadata);
            Err(error)
        }
    }
}
/// Runs the request-intercept chain (global plus scope-local intercepts) over
/// `request` and returns the resulting request.
///
/// No annotated request is supplied to the chain, and the second element of
/// the chain's result is discarded.
///
/// # Errors
/// Propagates the errors of `ensure_runtime_owner` and of the intercept chain,
/// and returns [`FlowError::Internal`] when the global context lock cannot be
/// read.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result<LlmRequest> {
    ensure_runtime_owner()?;
    let scope_stack = current_scope_stack();
    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
    let scope_locals =
        scope_guard.collect_scope_local_registries(|registries| &registries.llm_request_intercepts);
    let context = global_context();
    let state = context
        .read()
        .map_err(|error| FlowError::Internal(error.to_string()))?;
    let (request, _) = state.llm_request_intercepts_chain(name, request, None, &scope_locals)?;
    Ok(request)
}
/// Evaluates the conditional-execution guardrails (global plus scope-local)
/// against `request`.
///
/// Returns `Ok(())` when no guardrail rejects the request.
///
/// # Errors
/// Returns [`FlowError::GuardrailRejected`] carrying the rejection reported by
/// the chain, [`FlowError::Internal`] when the global context lock cannot be
/// read, and otherwise propagates the errors of `ensure_runtime_owner` and of
/// the guardrail chain.
///
/// # Panics
/// Panics if the scope stack lock is poisoned.
pub fn llm_conditional_execution(request: &LlmRequest) -> Result<()> {
    ensure_runtime_owner()?;
    let scope_stack = current_scope_stack();
    let scope_guard = scope_stack.read().expect("scope stack lock poisoned");
    let scope_locals = scope_guard.collect_scope_local_registries(|registries| {
        &registries.llm_conditional_execution_guardrails
    });
    let context = global_context();
    let state = context
        .read()
        .map_err(|error| FlowError::Internal(error.to_string()))?;
    let rejection = state.llm_conditional_execution_chain(request, &scope_locals)?;
    match rejection {
        Some(reason) => Err(FlowError::GuardrailRejected(reason)),
        None => Ok(()),
    }
}