use std::rc::Rc;
use crate::llm::trace::{emit_agent_event, AgentTraceEvent};
use crate::stdlib::json_stream::{JsonStreamStatus, StreamSchemaValidator};
use crate::value::{ErrorCategory, VmError, VmValue};
use super::options::LlmRequestPayload;
/// Watches streamed text deltas for output-schema violations.
///
/// Built by [`StreamSchemaWatch::from_payload`]; each delta is handed to
/// [`StreamSchemaWatch::observe`], which reports the first violation as a
/// [`SchemaStreamAbort`] and then stays silent.
pub(crate) struct StreamSchemaWatch {
/// Incremental validator that every non-empty delta is fed into.
validator: StreamSchemaValidator,
/// Provider name cloned from the request payload; copied into abort reports.
provider: String,
/// Model name cloned from the request payload; copied into abort reports.
model: String,
/// Number of non-empty deltas fed to the validator so far.
chunks_consumed: usize,
/// Set once a violation has been reported; `observe` ignores later deltas.
fired: bool,
}
/// Details of a mid-stream schema violation.
///
/// Produced by [`StreamSchemaWatch::observe`] and rebuilt from an error
/// message by [`parse_schema_stream_abort`].
#[derive(Clone, Debug)]
pub(crate) struct SchemaStreamAbort {
/// Provider name carried over from the watch that reported the violation.
pub provider: String,
/// Model name carried over from the watch that reported the violation.
pub model: String,
/// Reason taken from the validator's `Invalid` status.
pub reason: String,
/// Path taken from the validator's `Invalid` status (e.g. `$.age` in this
/// file's tests).
pub path: String,
/// Non-empty deltas consumed up to and including the offending one.
pub chunks_consumed: usize,
}
impl StreamSchemaWatch {
    /// Creates a watch for `payload`, or `None` when mid-stream validation
    /// does not apply.
    ///
    /// `None` is returned when `payload.schema_stream_abort` is false, when
    /// `payload.output_schema` is absent, or when
    /// `StreamSchemaValidator::from_json_schema` rejects the schema. The last
    /// case is logged as a warning instead of being surfaced as an error.
    pub(crate) fn from_payload(payload: &LlmRequestPayload) -> Option<Self> {
        // Both the opt-in flag and a schema are required.
        let schema = match payload.output_schema.as_ref() {
            Some(schema) if payload.schema_stream_abort => schema,
            _ => return None,
        };

        let validator = match StreamSchemaValidator::from_json_schema(schema) {
            Ok(validator) => validator,
            Err(err) => {
                let warning = format!(
                    "schema_stream_abort: failed to canonicalize output_schema, \
                     continuing without mid-stream validation: {err}"
                );
                crate::events::log_warn("llm", &warning);
                return None;
            }
        };

        Some(Self {
            validator,
            provider: payload.provider.clone(),
            model: payload.model.clone(),
            chunks_consumed: 0,
            fired: false,
        })
    }

    /// Feeds one streamed delta to the validator.
    ///
    /// Returns `Some(abort)` the first time the validator reports `Invalid`.
    /// On that call a `SchemaStreamAborted` trace event is emitted and, when a
    /// metrics registry is active, the abort is recorded there. Empty deltas
    /// and every delta after the first abort are ignored and return `None`
    /// without touching the validator or the chunk counter.
    pub(crate) fn observe(&mut self, delta: &str) -> Option<SchemaStreamAbort> {
        let should_feed = !self.fired && !delta.is_empty();
        if !should_feed {
            return None;
        }

        self.chunks_consumed += 1;
        match self.validator.feed(delta) {
            JsonStreamStatus::Invalid {
                reason: violation,
                path: location,
            } => {
                let abort = SchemaStreamAbort {
                    provider: self.provider.clone(),
                    model: self.model.clone(),
                    reason: violation.clone(),
                    path: location.clone(),
                    chunks_consumed: self.chunks_consumed,
                };
                // Latch so the violation is reported exactly once.
                self.fired = true;

                // The trace event owns its own copy of every field.
                let SchemaStreamAbort {
                    provider,
                    model,
                    reason,
                    path,
                    chunks_consumed,
                } = abort.clone();
                emit_agent_event(AgentTraceEvent::SchemaStreamAborted {
                    provider,
                    model,
                    reason,
                    path,
                    chunks_consumed,
                });

                if let Some(metrics) = crate::active_metrics_registry() {
                    metrics.record_schema_stream_aborted(&abort.provider, &abort.model);
                }
                Some(abort)
            }
            _ => None,
        }
    }
}
impl SchemaStreamAbort {
    /// Converts the abort into a `VmError::CategorizedError` tagged
    /// `ErrorCategory::SchemaStreamAborted`.
    ///
    /// The message has the fixed shape
    /// `schema_stream_aborted at {path}: {reason} (provider=… model=… chunks_consumed=…)`,
    /// which is the shape `parse_abort_message` reads back.
    pub(crate) fn into_vm_error(self) -> VmError {
        let Self {
            provider,
            model,
            reason,
            path,
            chunks_consumed,
        } = self;
        let message = format!(
            "schema_stream_aborted at {path}: {reason} \
             (provider={provider} model={model} chunks_consumed={chunks_consumed})"
        );
        VmError::CategorizedError {
            message,
            category: ErrorCategory::SchemaStreamAborted,
        }
    }
}
/// Recovers a [`SchemaStreamAbort`] from an error, if it is one.
///
/// Only a `VmError::CategorizedError` whose category is
/// `ErrorCategory::SchemaStreamAborted` is considered; its message is then
/// handed to `parse_abort_message`. Every other error yields `None`.
pub(crate) fn parse_schema_stream_abort(err: &VmError) -> Option<SchemaStreamAbort> {
    match err {
        VmError::CategorizedError {
            message,
            category: ErrorCategory::SchemaStreamAborted,
        } => parse_abort_message(message),
        _ => None,
    }
}
/// Parses a message rendered by `SchemaStreamAbort::into_vm_error` back into
/// its fields.
///
/// Expected shape:
/// `schema_stream_aborted at {path}: {reason} (provider={provider} model={model} chunks_consumed={n})`
///
/// Returns `None` when the leading prefix, the `": "` separator after the
/// path, or the ` (provider=` metadata marker is missing. Inside the metadata
/// a missing `model=` leaves the model empty, and a missing or non-numeric
/// `chunks_consumed=` yields `0`.
///
/// The metadata is split on its fixed markers, not on whitespace, so a
/// provider or model name that contains spaces is returned intact.
fn parse_abort_message(message: &str) -> Option<SchemaStreamAbort> {
    let body = message.strip_prefix("schema_stream_aborted at ")?;
    let (path, rest) = body.split_once(": ")?;
    // The reason is free-form text, so anchor on the *last* metadata marker.
    let (reason, meta) = rest.rsplit_once(" (provider=")?;
    // Drop exactly the one closing parenthesis the formatter appends.
    let meta = meta.strip_suffix(')').unwrap_or(meta);

    // `chunks_consumed` is the final field: peel it off from the right so
    // nothing inside the provider/model text can be mistaken for it.
    let (names, chunks_consumed) = match meta.rsplit_once(" chunks_consumed=") {
        Some((names, raw)) => (names, raw.trim().parse().unwrap_or(0)),
        None => (meta, 0),
    };
    // What remains is `{provider} model={model}`.
    let (provider, model) = names.split_once(" model=").unwrap_or((names, ""));

    Some(SchemaStreamAbort {
        provider: provider.to_string(),
        model: model.to_string(),
        reason: reason.to_string(),
        path: path.to_string(),
        chunks_consumed,
    })
}
/// Builds the result dict returned for a response that was aborted mid-stream.
///
/// The top-level dict carries an empty `text`, the abort's `model` and
/// `provider`, zero `input_tokens` / `output_tokens`, a nil `data`, and a
/// nested `schema_stream_aborted` dict holding `reason`, `path`,
/// `chunks_consumed`, `provider` and `model`.
pub(crate) fn aborted_result_value(abort: &SchemaStreamAbort) -> VmValue {
    // Wraps a borrowed string as a VM string value.
    let text = |s: &str| VmValue::String(Rc::from(s));

    let details = std::collections::BTreeMap::from([
        ("reason".to_string(), text(abort.reason.as_str())),
        ("path".to_string(), text(abort.path.as_str())),
        (
            "chunks_consumed".to_string(),
            VmValue::Int(abort.chunks_consumed as i64),
        ),
        ("provider".to_string(), text(abort.provider.as_str())),
        ("model".to_string(), text(abort.model.as_str())),
    ]);

    let result = std::collections::BTreeMap::from([
        ("text".to_string(), text("")),
        ("model".to_string(), text(abort.model.as_str())),
        ("provider".to_string(), text(abort.provider.as_str())),
        ("input_tokens".to_string(), VmValue::Int(0)),
        ("output_tokens".to_string(), VmValue::Int(0)),
        ("data".to_string(), VmValue::Nil),
        (
            "schema_stream_aborted".to_string(),
            VmValue::Dict(Rc::new(details)),
        ),
    ]);

    VmValue::Dict(Rc::new(result))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An abort rendered through `into_vm_error` must come back field for
    /// field from `parse_schema_stream_abort`.
    #[test]
    fn parses_round_trip_message() {
        let expected = SchemaStreamAbort {
            provider: String::from("openai"),
            model: String::from("gpt-test"),
            reason: String::from("expected type 'int', got JSON string"),
            path: String::from("$.age"),
            chunks_consumed: 3,
        };

        let rendered = expected.clone().into_vm_error();
        let recovered = parse_schema_stream_abort(&rendered).expect("parses");

        assert_eq!(recovered.provider, expected.provider);
        assert_eq!(recovered.model, expected.model);
        assert_eq!(recovered.reason, expected.reason);
        assert_eq!(recovered.path, expected.path);
        assert_eq!(recovered.chunks_consumed, expected.chunks_consumed);
    }

    /// Errors in any other category are not treated as schema-stream aborts.
    #[test]
    fn non_abort_error_is_none() {
        let unrelated = VmError::CategorizedError {
            message: String::from("something else"),
            category: ErrorCategory::Timeout,
        };
        let outcome = parse_schema_stream_abort(&unrelated);
        assert!(outcome.is_none());
    }
}