use crate::llm::transport::LlmTransportError;
use crate::llm::types::{
LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
LlmRequest, LlmRequestScope, LlmResponse, LlmRole, LlmStreamEvent, LlmTerminalReason,
LlmToolChoice,
};
use crate::provider::ProviderHandle;
use crate::{LashSchema, SchemaContract};
use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
use std::sync::Arc;
/// Speaker role attached to a [`DirectMessage`].
///
/// Variant names are serialized in `snake_case` (`system`, `user`,
/// `assistant`). Each variant is translated to the `LlmRole` variant of the
/// same name when the transport request is built.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DirectRole {
/// System-level instructions.
System,
/// Content supplied by the caller; the role used by [`DirectRequest::text`].
User,
/// Content attributed to the model.
Assistant,
}
/// One piece of content inside a [`DirectMessage`].
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DirectPart {
/// Literal text. Empty strings are skipped when the transport request is
/// built.
Text(String),
/// Image reference, forwarded as the `attachment_idx` of an image content
/// block. Presumably an index into `DirectRequest::attachments`; no bounds
/// check is performed in this file.
Image(usize),
}
/// A single conversational message in a [`DirectRequest`].
///
/// A message whose parts all normalize to nothing (for example only empty
/// text) is dropped entirely when the transport request is built.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectMessage {
/// Who the message is attributed to.
pub role: DirectRole,
/// Ordered content parts; order is preserved in the transport request.
pub parts: Vec<DirectPart>,
}
/// Named JSON schema describing the structured output a caller expects.
///
/// All three fields are copied verbatim into the transport-level
/// `LlmJsonSchema`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectJsonSchema {
/// Name of the schema as sent to the provider.
pub name: String,
/// Schema contract; its canonical form is what responses are validated
/// against after completion.
pub schema: SchemaContract,
/// Strictness flag passed through to the provider unchanged.
pub strict: bool,
}
/// Output format requested for a direct completion.
#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum DirectOutputSpec {
/// Free-form text (the default). No output spec is forwarded to the
/// provider.
#[default]
Text,
/// Forwarded as `LlmOutputSpec::JsonObject`. The response is not validated
/// by this module.
JsonObject,
/// Forwarded as `LlmOutputSpec::JsonSchema`; the response text is parsed as
/// JSON and validated against the schema after completion.
JsonSchema(DirectJsonSchema),
}
/// A tool-less, single-shot completion request issued through
/// [`DirectLlmClient`].
///
/// Every field except `model` has a serde default, so a serialized request
/// only needs to carry the model name.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DirectRequest {
/// Model identifier.
pub model: String,
/// Optional model variant; when present it is validated against the
/// provider before the request is sent.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_variant: Option<String>,
/// Conversation messages, in order.
#[serde(default)]
pub messages: Vec<DirectMessage>,
/// Attachments forwarded unchanged to the transport request.
#[serde(default)]
pub attachments: Vec<LlmAttachment>,
/// Requested output format; defaults to plain text.
#[serde(default)]
pub output: DirectOutputSpec,
/// Generation options forwarded unchanged to the transport request.
#[serde(default)]
pub generation: crate::GenerationOptions,
/// Optional sink for streaming events. Never serialized or deserialized.
#[serde(default, skip)]
pub stream_events: Option<LlmEventSender>,
/// Optional session identifier; when set, the transport request scope is
/// derived from it instead of from a freshly generated UUID.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Optional causal reference. Not forwarded by `build_llm_request`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub caused_by: Option<crate::CausalRef>,
/// Optional replay descriptor. Not forwarded by `build_llm_request`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub replay: Option<crate::RuntimeReplay>,
}
impl DirectRequest {
    /// Creates a plain-text request containing a single user message with
    /// `prompt` as its only part. Every optional field is left unset.
    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        let model = model.into();
        let user_turn = DirectMessage {
            role: DirectRole::User,
            parts: vec![DirectPart::Text(prompt.into())],
        };
        Self {
            model,
            model_variant: None,
            messages: vec![user_turn],
            attachments: Vec::new(),
            output: DirectOutputSpec::Text,
            generation: crate::GenerationOptions::default(),
            stream_events: None,
            session_id: None,
            caused_by: None,
            replay: None,
        }
    }

    /// Same as [`DirectRequest::text`], but asks for a JSON object as output.
    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
        let mut request = Self::text(model, prompt);
        request.output = DirectOutputSpec::JsonObject;
        request
    }

    /// Same as [`DirectRequest::text`], but asks for output conforming to
    /// `schema`.
    pub fn json_schema(
        model: impl Into<String>,
        prompt: impl Into<String>,
        schema: DirectJsonSchema,
    ) -> Self {
        let mut request = Self::text(model, prompt);
        request.output = DirectOutputSpec::JsonSchema(schema);
        request
    }

    /// Returns the request with its replay descriptor set to `key`,
    /// replacing any previous value.
    pub fn with_replay_key(self, key: impl Into<String>) -> Self {
        Self {
            replay: Some(crate::RuntimeReplay { key: key.into() }),
            ..self
        }
    }

    /// Returns the request with its causal reference set to `caused_by`,
    /// replacing any previous value.
    pub fn with_caused_by(self, caused_by: crate::CausalRef) -> Self {
        Self {
            caused_by: Some(caused_by),
            ..self
        }
    }
}
/// Errors returned by [`DirectLlmClient::complete`].
#[derive(Debug, thiserror::Error, Clone)]
pub enum DirectLlmError {
/// The request was rejected before being sent; produced when the provider
/// refuses the requested model variant.
#[error("invalid request: {0}")]
InvalidRequest(String),
/// The provider answered, but the response text was not valid JSON or did
/// not satisfy the requested JSON schema.
#[error("invalid response: {0}")]
InvalidResponse(String),
/// The provider call itself failed. The transport error is boxed.
#[error("transport error: {0}")]
Transport(#[from] Box<LlmTransportError>),
}
/// Client that sends [`DirectRequest`]s to a single provider and optionally
/// records trace events for each call.
pub struct DirectLlmClient {
// Provider that executes completions and validates model variants.
provider: ProviderHandle,
// Trace destination; tracing is skipped entirely when this is `None`.
trace_sink: Option<Arc<dyn TraceSink>>,
// Context passed along with every emitted trace event.
trace_context: TraceContext,
// Clock handed to the trace emitter; defaults to `crate::SystemClock`.
clock: Arc<dyn crate::Clock>,
}
impl DirectLlmClient {
pub fn new(provider: ProviderHandle) -> Self {
Self {
provider,
trace_sink: None,
trace_context: TraceContext::default(),
clock: Arc::new(crate::SystemClock),
}
}
pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
self.trace_sink = sink;
self
}
pub fn with_trace_context(mut self, context: TraceContext) -> Self {
self.trace_context = context;
self
}
pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
self.clock = clock;
self
}
pub fn provider(&self) -> &ProviderHandle {
&self.provider
}
pub fn provider_mut(&mut self) -> &mut ProviderHandle {
&mut self.provider
}
pub async fn complete(
&mut self,
request: DirectRequest,
) -> Result<LlmResponse, DirectLlmError> {
if let Some(variant) = request.model_variant.as_deref() {
self.provider
.validate_variant(&request.model, variant)
.map_err(DirectLlmError::InvalidRequest)?;
}
let output_for_validation = request.output.clone();
let model = request.model.clone();
let llm_request = build_llm_request(&self.provider, request, model);
let llm_call_id = if self.trace_sink.is_some() {
let id = uuid::Uuid::new_v4().to_string();
crate::trace::emit_trace(
&self.trace_sink,
&self.trace_context,
TraceContext::default().for_llm_call(id.clone()),
TraceEvent::LlmCallStarted {
request: crate::trace::trace_llm_request(&llm_request),
},
self.clock.as_ref(),
);
Some(id)
} else {
None
};
match self.provider.complete(llm_request).await {
Ok(response) => {
if let Err(error) = validate_direct_output(&output_for_validation, &response) {
if let Some(llm_call_id) = llm_call_id {
crate::trace::emit_trace(
&self.trace_sink,
&self.trace_context,
TraceContext::default().for_llm_call(llm_call_id),
TraceEvent::LlmCallFailed {
error: TraceError {
message: error.to_string(),
retryable: false,
terminal_reason: Some(
LlmTerminalReason::ProviderError.code().to_string(),
),
code: Some("invalid_structured_output".to_string()),
raw: None,
},
stream_summary: None,
},
self.clock.as_ref(),
);
}
return Err(error);
}
if let Some(llm_call_id) = llm_call_id {
crate::trace::emit_trace(
&self.trace_sink,
&self.trace_context,
TraceContext::default().for_llm_call(llm_call_id),
TraceEvent::LlmCallCompleted {
response: crate::trace::trace_llm_response(
response.full_text.clone(),
0,
Some(response.terminal_reason),
crate::trace::trace_output_parts(&response.parts),
),
usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
provider_usage: response.provider_usage.clone(),
stream_summary: None,
},
self.clock.as_ref(),
);
}
Ok(response)
}
Err(error) => {
if let Some(llm_call_id) = llm_call_id {
crate::trace::emit_trace(
&self.trace_sink,
&self.trace_context,
TraceContext::default().for_llm_call(llm_call_id),
TraceEvent::LlmCallFailed {
error: TraceError {
message: error.message.clone(),
retryable: error.retryable,
terminal_reason: Some(error.terminal_reason.code().to_string()),
code: error.code.clone(),
raw: error.raw.clone(),
},
stream_summary: None,
},
self.clock.as_ref(),
);
}
Err(DirectLlmError::from(Box::new(error)))
}
}
}
}
pub(crate) fn build_llm_request(
provider: &ProviderHandle,
request: DirectRequest,
model: String,
) -> LlmRequest {
let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
let DirectRequest {
model: _,
model_variant,
messages,
attachments,
output,
generation,
stream_events: _,
session_id,
caused_by: _,
replay: _,
} = request;
let output_spec = match output {
DirectOutputSpec::Text => None,
DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
name: schema.name,
schema: schema.schema,
strict: schema.strict,
})),
};
let mut llm_messages = Vec::new();
for message in messages {
let role = match message.role {
DirectRole::System => LlmRole::System,
DirectRole::User => LlmRole::User,
DirectRole::Assistant => LlmRole::Assistant,
};
let mut blocks: Vec<LlmContentBlock> = Vec::new();
for part in message.parts {
match part {
DirectPart::Text(text) => {
if !text.is_empty() {
blocks.push(LlmContentBlock::Text {
text: text.into(),
response_meta: None,
cache_breakpoint: false,
});
}
}
DirectPart::Image(idx) => {
blocks.push(LlmContentBlock::Image {
attachment_idx: idx,
});
}
}
}
if !blocks.is_empty() {
llm_messages.push(LlmMessage::new(role, blocks));
}
}
let scope = match session_id {
Some(session_id) => LlmRequestScope::new(
session_id.clone(),
format!("{session_id}:frame:direct"),
format!("{session_id}:direct"),
),
None => {
let request_id = uuid::Uuid::new_v4().to_string();
LlmRequestScope::new(
format!("direct:{request_id}"),
format!("direct:{request_id}:frame"),
request_id,
)
}
};
LlmRequest {
model,
messages: llm_messages,
attachments,
tools: Vec::new().into(),
tool_choice: LlmToolChoice::None,
model_variant,
generation,
scope,
output_spec,
stream_events,
provider_trace: None,
}
}
/// Checks `response` against the requested output format.
///
/// Only [`DirectOutputSpec::JsonSchema`] is validated: the trimmed response
/// text must parse as JSON and satisfy the canonical form of the schema. The
/// other formats always pass.
///
/// # Errors
///
/// Returns [`DirectLlmError::InvalidResponse`] when parsing or schema
/// validation fails.
fn validate_direct_output(
    output: &DirectOutputSpec,
    response: &LlmResponse,
) -> Result<(), DirectLlmError> {
    match output {
        DirectOutputSpec::Text | DirectOutputSpec::JsonObject => Ok(()),
        DirectOutputSpec::JsonSchema(expected) => {
            let payload = response.full_text.trim();
            let parsed = serde_json::from_str::<serde_json::Value>(payload).map_err(|err| {
                DirectLlmError::InvalidResponse(format!("expected JSON: {err}"))
            })?;
            LashSchema::new(expected.schema.canonical().clone())
                .validate(&parsed)
                .map_err(DirectLlmError::InvalidResponse)
        }
    }
}
/// Chooses the stream-event sender for the transport request.
///
/// A sender supplied by the caller is always kept. Otherwise a sender that
/// discards every event is created when the provider requires streaming, and
/// `None` is returned when it does not.
fn transport_stream_events_for_direct(
    provider: &ProviderHandle,
    requested: Option<LlmEventSender>,
) -> Option<LlmEventSender> {
    requested.or_else(|| {
        provider
            .requires_streaming()
            .then(|| LlmEventSender::new(|_event: LlmStreamEvent| {}))
    })
}
// Unit tests for the direct client: request constructors, provider
// accessors, completion delegation, schema validation and request
// normalization.
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::{LlmOutputPart, LlmTerminalReason, LlmUsage};
use crate::provider::{ProviderOptions, ProviderReliability};
use crate::testing::TestProvider;
use serde_json::json;
use std::sync::{Arc, Mutex};
// `DirectRequest::json_schema` must store the schema it was given, unchanged.
#[test]
fn json_schema_request_preserves_output_schema() {
let schema = DirectJsonSchema {
name: "answer_shape".to_string(),
schema: json!({
"type": "object",
"properties": {
"answer": { "type": "string" }
},
"required": ["answer"]
})
.into(),
strict: true,
};
let request = DirectRequest::json_schema("model-a", "return json", schema.clone());
assert_eq!(
request.output,
DirectOutputSpec::JsonSchema(schema),
"DirectRequest::json_schema must carry the requested output schema"
);
}
// `provider()` exposes the handle the client was built with, and changes made
// through `provider_mut()` are visible through `provider()`.
#[test]
fn direct_client_provider_accessors_expose_owned_provider_handle() {
let provider = TestProvider::builder()
.kind("direct-accessor-provider")
.serialize_config(|| json!({"provider": "owned"}))
.build()
.into_handle();
let mut client = DirectLlmClient::new(provider);
assert_eq!(client.provider().kind(), "direct-accessor-provider");
assert_eq!(
client.provider().to_spec().config,
json!({"provider": "owned"})
);
let options = ProviderOptions {
reliability: ProviderReliability::default().max_attempts(7),
max_output_tokens: Some(123),
..Default::default()
};
client.provider_mut().set_options(options.clone());
assert_eq!(client.provider().options(), options);
}
// `complete` forwards the normalized request to the provider and returns the
// provider's response; the scope ids are derived from the session id.
#[tokio::test]
async fn direct_client_complete_delegates_to_provider_and_returns_response() {
// Shared slot in which the fake provider records the request it received.
let captured_request: Arc<Mutex<Option<LlmRequest>>> = Arc::new(Mutex::new(None));
let captured_for_provider = Arc::clone(&captured_request);
let provider = TestProvider::builder()
.kind("direct-complete-provider")
.complete(move |request| {
let captured_for_provider = Arc::clone(&captured_for_provider);
async move {
*captured_for_provider.lock().expect("capture lock") = Some(request);
Ok(LlmResponse {
full_text: "provider delegated response".to_string(),
parts: vec![LlmOutputPart::Text {
text: "provider delegated response".to_string(),
response_meta: None,
}],
usage: LlmUsage {
input_tokens: 11,
output_tokens: 3,
..Default::default()
},
terminal_reason: LlmTerminalReason::Stop,
..Default::default()
})
}
})
.build()
.into_handle();
let mut client = DirectLlmClient::new(provider);
let mut request = DirectRequest::json("direct-model", "answer as json");
request.session_id = Some("direct-session".to_string());
let response = client
.complete(request)
.await
.expect("direct completion should delegate");
assert_eq!(response.full_text, "provider delegated response");
let captured = captured_request
.lock()
.expect("capture lock")
.clone()
.expect("provider should receive a request");
assert_eq!(captured.model, "direct-model");
assert_eq!(captured.scope.session_id, "direct-session");
assert_eq!(captured.scope.agent_frame_id, "direct-session:frame:direct");
assert_eq!(captured.scope.request_id, "direct-session:direct");
assert!(matches!(
captured.output_spec,
Some(LlmOutputSpec::JsonObject)
));
assert_eq!(captured.messages.len(), 1);
}
// A response that parses as JSON but violates the schema (`minItems: 1` with
// an empty array) must surface as `InvalidResponse`.
#[tokio::test]
async fn direct_client_validates_json_schema_output_against_canonical_schema() {
let provider = TestProvider::builder()
.kind("direct-validation-provider")
.complete(|_request| async {
Ok(LlmResponse {
full_text: r#"{"items":[]}"#.to_string(),
terminal_reason: LlmTerminalReason::Stop,
..Default::default()
})
})
.build()
.into_handle();
let mut client = DirectLlmClient::new(provider);
let request = DirectRequest::json_schema(
"direct-model",
"return items",
DirectJsonSchema {
name: "items_result".to_string(),
schema: json!({
"type": "object",
"required": ["items"],
"properties": {
"items": {
"type": "array",
"minItems": 1,
"items": { "type": "string" }
}
}
})
.into(),
strict: true,
},
);
let err = client
.complete(request)
.await
.expect_err("empty items must fail canonical validation");
assert!(matches!(err, DirectLlmError::InvalidResponse(_)));
assert!(err.to_string().contains("items >= 1"));
}
// Empty text parts are skipped, messages left without blocks are dropped, and
// the explicit `model` argument overrides the request's own model field.
#[test]
fn build_llm_request_preserves_nonempty_content_and_drops_empty_messages() {
let provider = TestProvider::default().into_handle();
let request = DirectRequest {
model: "input-model".to_string(),
messages: vec![
// Only an empty text part: the whole message is expected to disappear.
DirectMessage {
role: DirectRole::System,
parts: vec![DirectPart::Text(String::new())],
},
// One real part plus one empty part: a single block should remain.
DirectMessage {
role: DirectRole::User,
parts: vec![
DirectPart::Text("hello".to_string()),
DirectPart::Text(String::new()),
],
},
DirectMessage {
role: DirectRole::Assistant,
parts: vec![DirectPart::Image(2)],
},
],
attachments: Vec::new(),
output: DirectOutputSpec::Text,
generation: crate::GenerationOptions::default(),
stream_events: None,
session_id: None,
model_variant: None,
caused_by: None,
replay: None,
};
let llm_request = build_llm_request(&provider, request, "transport-model".to_string());
assert_eq!(llm_request.model, "transport-model");
assert_eq!(
llm_request.messages.len(),
2,
"empty normalized messages must be dropped"
);
assert_eq!(llm_request.messages[0].role, LlmRole::User);
assert_eq!(llm_request.messages[0].blocks.len(), 1);
assert!(matches!(
&llm_request.messages[0].blocks[0],
LlmContentBlock::Text { text, .. } if text.as_ref() == "hello"
));
assert_eq!(llm_request.messages[1].role, LlmRole::Assistant);
assert!(matches!(
&llm_request.messages[1].blocks[0],
LlmContentBlock::Image { attachment_idx: 2 }
));
}
// A caller-supplied sender is passed through untouched, and a provider that
// requires streaming gets a sender even when the caller supplied none.
#[test]
fn build_llm_request_preserves_direct_stream_sender_and_adds_required_noop_sender() {
let captured_events: Arc<Mutex<Vec<LlmStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
let captured_for_sender = Arc::clone(&captured_events);
let requested_sender = LlmEventSender::new(move |event| {
captured_for_sender
.lock()
.expect("stream event lock")
.push(event);
});
let mut request = DirectRequest::text("model", "prompt");
request.stream_events = Some(requested_sender);
let provider = TestProvider::default().into_handle();
let llm_request = build_llm_request(&provider, request, "model".to_string());
let sender = llm_request
.stream_events
.expect("explicit direct stream sender must be preserved");
// Sending through the returned sender must reach the caller's closure.
sender.send(LlmStreamEvent::Delta("delta".to_string()));
assert_eq!(captured_events.lock().expect("stream event lock").len(), 1);
let streaming_provider = TestProvider::builder()
.requires_streaming(true)
.build()
.into_handle();
let llm_request = build_llm_request(
&streaming_provider,
DirectRequest::text("model", "prompt"),
"model".to_string(),
);
assert!(
llm_request.stream_events.is_some(),
"providers that require streaming need a no-op sender even when direct caller did not request one"
);
}
}