1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3 LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4 LlmRequest, LlmRequestScope, LlmResponse, LlmRole, LlmStreamEvent, LlmTerminalReason,
5 LlmToolChoice,
6};
7use crate::provider::ProviderHandle;
8use crate::{LashSchema, SchemaContract};
9use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
10use std::sync::Arc;
11
12#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum DirectRole {
15 System,
16 User,
17 Assistant,
18}
19
20#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum DirectPart {
22 Text(String),
23 Image(usize),
24}
25
26#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
27pub struct DirectMessage {
28 pub role: DirectRole,
29 pub parts: Vec<DirectPart>,
30}
31
32#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
33pub struct DirectJsonSchema {
34 pub name: String,
35 pub schema: SchemaContract,
36 pub strict: bool,
37}
38
39#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
40pub enum DirectOutputSpec {
41 #[default]
42 Text,
43 JsonObject,
44 JsonSchema(DirectJsonSchema),
45}
46
47#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
48pub struct DirectRequest {
49 pub model: String,
50 #[serde(default, skip_serializing_if = "Option::is_none")]
51 pub model_variant: Option<String>,
52 #[serde(default)]
53 pub messages: Vec<DirectMessage>,
54 #[serde(default)]
55 pub attachments: Vec<LlmAttachment>,
56 #[serde(default)]
57 pub output: DirectOutputSpec,
58 #[serde(default)]
59 pub generation: crate::GenerationOptions,
60 #[serde(default, skip)]
61 pub stream_events: Option<LlmEventSender>,
62 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub session_id: Option<String>,
64 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub caused_by: Option<crate::CausalRef>,
66 #[serde(default, skip_serializing_if = "Option::is_none")]
67 pub replay: Option<crate::RuntimeReplay>,
68}
69
70impl DirectRequest {
71 pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
72 Self {
73 model: model.into(),
74 model_variant: None,
75 messages: vec![DirectMessage {
76 role: DirectRole::User,
77 parts: vec![DirectPart::Text(prompt.into())],
78 }],
79 attachments: Vec::new(),
80 output: DirectOutputSpec::Text,
81 generation: crate::GenerationOptions::default(),
82 stream_events: None,
83 session_id: None,
84 caused_by: None,
85 replay: None,
86 }
87 }
88
89 pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
90 Self {
91 output: DirectOutputSpec::JsonObject,
92 ..Self::text(model, prompt)
93 }
94 }
95
96 pub fn json_schema(
97 model: impl Into<String>,
98 prompt: impl Into<String>,
99 schema: DirectJsonSchema,
100 ) -> Self {
101 Self {
102 output: DirectOutputSpec::JsonSchema(schema),
103 ..Self::text(model, prompt)
104 }
105 }
106
107 pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
108 self.replay = Some(crate::RuntimeReplay { key: key.into() });
109 self
110 }
111
112 pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
113 self.caused_by = Some(caused_by);
114 self
115 }
116}
117
118#[derive(Debug, thiserror::Error, Clone)]
119pub enum DirectLlmError {
120 #[error("invalid request: {0}")]
121 InvalidRequest(String),
122 #[error("invalid response: {0}")]
123 InvalidResponse(String),
124 #[error("transport error: {0}")]
125 Transport(#[from] LlmTransportError),
126}
127
128pub struct DirectLlmClient {
129 provider: ProviderHandle,
130 trace_sink: Option<Arc<dyn TraceSink>>,
131 trace_context: TraceContext,
132 clock: Arc<dyn crate::Clock>,
133}
134
135impl DirectLlmClient {
136 pub fn new(provider: ProviderHandle) -> Self {
137 Self {
138 provider,
139 trace_sink: None,
140 trace_context: TraceContext::default(),
141 clock: Arc::new(crate::SystemClock),
142 }
143 }
144
145 pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
146 self.trace_sink = sink;
147 self
148 }
149
150 pub fn with_trace_context(mut self, context: TraceContext) -> Self {
151 self.trace_context = context;
152 self
153 }
154
155 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
156 self.clock = clock;
157 self
158 }
159
160 pub fn provider(&self) -> &ProviderHandle {
161 &self.provider
162 }
163
164 pub fn provider_mut(&mut self) -> &mut ProviderHandle {
165 &mut self.provider
166 }
167
168 pub async fn complete(
169 &mut self,
170 request: DirectRequest,
171 ) -> Result<LlmResponse, DirectLlmError> {
172 if let Some(variant) = request.model_variant.as_deref() {
173 self.provider
174 .validate_variant(&request.model, variant)
175 .map_err(DirectLlmError::InvalidRequest)?;
176 }
177
178 let output_for_validation = request.output.clone();
179 let model = request.model.clone();
180 let llm_request = build_llm_request(&self.provider, request, model);
181 let llm_call_id = if self.trace_sink.is_some() {
182 let id = uuid::Uuid::new_v4().to_string();
183 crate::trace::emit_trace(
184 &self.trace_sink,
185 &self.trace_context,
186 TraceContext::default().for_llm_call(id.clone()),
187 TraceEvent::LlmCallStarted {
188 request: crate::trace::trace_llm_request(&llm_request),
189 },
190 self.clock.as_ref(),
191 );
192 Some(id)
193 } else {
194 None
195 };
196 match self.provider.complete(llm_request).await {
197 Ok(response) => {
198 if let Err(error) = validate_direct_output(&output_for_validation, &response) {
199 if let Some(llm_call_id) = llm_call_id {
200 crate::trace::emit_trace(
201 &self.trace_sink,
202 &self.trace_context,
203 TraceContext::default().for_llm_call(llm_call_id),
204 TraceEvent::LlmCallFailed {
205 error: TraceError {
206 message: error.to_string(),
207 retryable: false,
208 terminal_reason: Some(
209 LlmTerminalReason::ProviderError.code().to_string(),
210 ),
211 code: Some("invalid_structured_output".to_string()),
212 raw: None,
213 },
214 stream_summary: None,
215 },
216 self.clock.as_ref(),
217 );
218 }
219 return Err(error);
220 }
221 if let Some(llm_call_id) = llm_call_id {
222 crate::trace::emit_trace(
223 &self.trace_sink,
224 &self.trace_context,
225 TraceContext::default().for_llm_call(llm_call_id),
226 TraceEvent::LlmCallCompleted {
227 response: crate::trace::trace_llm_response(
228 response.full_text.clone(),
229 0,
230 Some(response.terminal_reason),
231 crate::trace::trace_output_parts(&response.parts),
232 ),
233 usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
234 provider_usage: response.provider_usage.clone(),
235 stream_summary: None,
236 },
237 self.clock.as_ref(),
238 );
239 }
240 Ok(response)
241 }
242 Err(error) => {
243 if let Some(llm_call_id) = llm_call_id {
244 crate::trace::emit_trace(
245 &self.trace_sink,
246 &self.trace_context,
247 TraceContext::default().for_llm_call(llm_call_id),
248 TraceEvent::LlmCallFailed {
249 error: TraceError {
250 message: error.message.clone(),
251 retryable: error.retryable,
252 terminal_reason: Some(error.terminal_reason.code().to_string()),
253 code: error.code.clone(),
254 raw: error.raw.clone(),
255 },
256 stream_summary: None,
257 },
258 self.clock.as_ref(),
259 );
260 }
261 Err(DirectLlmError::from(error))
262 }
263 }
264 }
265}
266
267pub(crate) fn build_llm_request(
268 provider: &ProviderHandle,
269 request: DirectRequest,
270 model: String,
271) -> LlmRequest {
272 let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
273 let DirectRequest {
274 model: _,
275 model_variant,
276 messages,
277 attachments,
278 output,
279 generation,
280 stream_events: _,
281 session_id,
282 caused_by: _,
283 replay: _,
284 } = request;
285
286 let output_spec = match output {
287 DirectOutputSpec::Text => None,
288 DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
289 DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
290 name: schema.name,
291 schema: schema.schema,
292 strict: schema.strict,
293 })),
294 };
295
296 let mut llm_messages = Vec::new();
297 for message in messages {
298 let role = match message.role {
299 DirectRole::System => LlmRole::System,
300 DirectRole::User => LlmRole::User,
301 DirectRole::Assistant => LlmRole::Assistant,
302 };
303 let mut blocks: Vec<LlmContentBlock> = Vec::new();
304 for part in message.parts {
305 match part {
306 DirectPart::Text(text) => {
307 if !text.is_empty() {
308 blocks.push(LlmContentBlock::Text {
309 text: text.into(),
310 response_meta: None,
311 cache_breakpoint: false,
312 });
313 }
314 }
315 DirectPart::Image(idx) => {
316 blocks.push(LlmContentBlock::Image {
317 attachment_idx: idx,
318 });
319 }
320 }
321 }
322 if !blocks.is_empty() {
323 llm_messages.push(LlmMessage::new(role, blocks));
324 }
325 }
326
327 let scope = match session_id {
328 Some(session_id) => LlmRequestScope::new(
329 session_id.clone(),
330 format!("{session_id}:frame:direct"),
331 format!("{session_id}:direct"),
332 ),
333 None => {
334 let request_id = uuid::Uuid::new_v4().to_string();
335 LlmRequestScope::new(
336 format!("direct:{request_id}"),
337 format!("direct:{request_id}:frame"),
338 request_id,
339 )
340 }
341 };
342
343 LlmRequest {
344 model,
345 messages: llm_messages,
346 attachments,
347 tools: Vec::new().into(),
348 tool_choice: LlmToolChoice::None,
349 model_variant,
350 generation,
351 scope,
352 output_spec,
353 stream_events,
354 provider_trace: None,
355 }
356}
357
358fn validate_direct_output(
359 output: &DirectOutputSpec,
360 response: &LlmResponse,
361) -> Result<(), DirectLlmError> {
362 let DirectOutputSpec::JsonSchema(schema) = output else {
363 return Ok(());
364 };
365 let parsed: serde_json::Value = serde_json::from_str(response.full_text.trim())
366 .map_err(|err| DirectLlmError::InvalidResponse(format!("expected JSON: {err}")))?;
367 LashSchema::new(schema.schema.canonical().clone())
368 .validate(&parsed)
369 .map_err(DirectLlmError::InvalidResponse)
370}
371
372fn transport_stream_events_for_direct(
373 provider: &ProviderHandle,
374 requested: Option<LlmEventSender>,
375) -> Option<LlmEventSender> {
376 if requested.is_some() {
377 return requested;
378 }
379 if provider.requires_streaming() {
380 Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
381 } else {
382 None
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{LlmOutputPart, LlmTerminalReason, LlmUsage};
    use crate::provider::{ProviderOptions, ProviderReliability};
    use crate::testing::TestProvider;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// The `json_schema` constructor must store the caller's schema as given.
    #[test]
    fn json_schema_request_preserves_output_schema() {
        let expected_schema = DirectJsonSchema {
            name: String::from("answer_shape"),
            schema: json!({
                "type": "object",
                "properties": {
                    "answer": { "type": "string" }
                },
                "required": ["answer"]
            })
            .into(),
            strict: true,
        };

        let built = DirectRequest::json_schema("model-a", "return json", expected_schema.clone());

        assert_eq!(
            built.output,
            DirectOutputSpec::JsonSchema(expected_schema),
            "DirectRequest::json_schema must carry the requested output schema"
        );
    }

    /// `provider()` / `provider_mut()` must expose the handle the client owns.
    #[test]
    fn direct_client_provider_accessors_expose_owned_provider_handle() {
        let handle = TestProvider::builder()
            .kind("direct-accessor-provider")
            .serialize_config(|| json!({"provider": "owned"}))
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);

        assert_eq!(client.provider().kind(), "direct-accessor-provider");
        assert_eq!(
            client.provider().to_spec().config,
            json!({"provider": "owned"})
        );

        let mut options = ProviderOptions::default();
        options.reliability = ProviderReliability::default().max_attempts(7);
        options.max_output_tokens = Some(123);
        client.provider_mut().set_options(options.clone());

        assert_eq!(client.provider().options(), options);
    }

    /// `complete` must hand the lowered request to the provider and return
    /// the provider's response.
    #[tokio::test]
    async fn direct_client_complete_delegates_to_provider_and_returns_response() {
        let seen_request: Arc<Mutex<Option<LlmRequest>>> = Arc::default();
        let provider_slot = Arc::clone(&seen_request);
        let handle = TestProvider::builder()
            .kind("direct-complete-provider")
            .complete(move |request| {
                let slot = Arc::clone(&provider_slot);
                async move {
                    *slot.lock().expect("capture lock") = Some(request);
                    Ok(LlmResponse {
                        full_text: "provider delegated response".to_string(),
                        parts: vec![LlmOutputPart::Text {
                            text: "provider delegated response".to_string(),
                            response_meta: None,
                        }],
                        usage: LlmUsage {
                            input_tokens: 11,
                            output_tokens: 3,
                            ..Default::default()
                        },
                        terminal_reason: LlmTerminalReason::Stop,
                        ..Default::default()
                    })
                }
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);
        let request = DirectRequest {
            session_id: Some("direct-session".to_string()),
            ..DirectRequest::json("direct-model", "answer as json")
        };

        let response = client
            .complete(request)
            .await
            .expect("direct completion should delegate");

        assert_eq!(response.full_text, "provider delegated response");
        let forwarded = seen_request
            .lock()
            .expect("capture lock")
            .clone()
            .expect("provider should receive a request");
        assert_eq!(forwarded.model, "direct-model");
        assert_eq!(forwarded.scope.session_id, "direct-session");
        assert_eq!(
            forwarded.scope.agent_frame_id,
            "direct-session:frame:direct"
        );
        assert_eq!(forwarded.scope.request_id, "direct-session:direct");
        assert!(matches!(
            forwarded.output_spec,
            Some(LlmOutputSpec::JsonObject)
        ));
        assert_eq!(forwarded.messages.len(), 1);
    }

    /// A reply that violates the requested schema must surface as
    /// `InvalidResponse`.
    #[tokio::test]
    async fn direct_client_validates_json_schema_output_against_canonical_schema() {
        let handle = TestProvider::builder()
            .kind("direct-validation-provider")
            .complete(|_request| async {
                Ok(LlmResponse {
                    full_text: r#"{"items":[]}"#.to_string(),
                    terminal_reason: LlmTerminalReason::Stop,
                    ..Default::default()
                })
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);
        let items_schema = DirectJsonSchema {
            name: "items_result".to_string(),
            schema: json!({
                "type": "object",
                "required": ["items"],
                "properties": {
                    "items": {
                        "type": "array",
                        "minItems": 1,
                        "items": { "type": "string" }
                    }
                }
            })
            .into(),
            strict: true,
        };
        let request = DirectRequest::json_schema("direct-model", "return items", items_schema);

        let failure = client
            .complete(request)
            .await
            .expect_err("empty items must fail canonical validation");

        assert!(matches!(failure, DirectLlmError::InvalidResponse(_)));
        assert!(failure.to_string().contains("items >= 1"));
    }

    /// Empty text parts are skipped and messages without blocks are dropped;
    /// the explicit model argument wins over `request.model`.
    #[test]
    fn build_llm_request_preserves_nonempty_content_and_drops_empty_messages() {
        let handle = TestProvider::default().into_handle();
        let turns = vec![
            DirectMessage {
                role: DirectRole::System,
                parts: vec![DirectPart::Text(String::new())],
            },
            DirectMessage {
                role: DirectRole::User,
                parts: vec![
                    DirectPart::Text("hello".to_string()),
                    DirectPart::Text(String::new()),
                ],
            },
            DirectMessage {
                role: DirectRole::Assistant,
                parts: vec![DirectPart::Image(2)],
            },
        ];
        let request = DirectRequest {
            messages: turns,
            ..DirectRequest::text("input-model", "")
        };

        let lowered = build_llm_request(&handle, request, "transport-model".to_string());

        assert_eq!(lowered.model, "transport-model");
        assert_eq!(
            lowered.messages.len(),
            2,
            "empty normalized messages must be dropped"
        );

        let user_turn = &lowered.messages[0];
        assert_eq!(user_turn.role, LlmRole::User);
        assert_eq!(user_turn.blocks.len(), 1);
        assert!(matches!(
            &user_turn.blocks[0],
            LlmContentBlock::Text { text, .. } if text.as_ref() == "hello"
        ));

        let assistant_turn = &lowered.messages[1];
        assert_eq!(assistant_turn.role, LlmRole::Assistant);
        assert!(matches!(
            &assistant_turn.blocks[0],
            LlmContentBlock::Image { attachment_idx: 2 }
        ));
    }

    /// A caller-supplied sender is forwarded untouched; a streaming-only
    /// provider gets a no-op sender when the caller supplied none.
    #[test]
    fn build_llm_request_preserves_direct_stream_sender_and_adds_required_noop_sender() {
        let received: Arc<Mutex<Vec<LlmStreamEvent>>> = Arc::default();
        let sink = Arc::clone(&received);
        let explicit_sender = LlmEventSender::new(move |event| {
            sink.lock().expect("stream event lock").push(event);
        });
        let request = DirectRequest {
            stream_events: Some(explicit_sender),
            ..DirectRequest::text("model", "prompt")
        };
        let plain_provider = TestProvider::default().into_handle();

        let lowered = build_llm_request(&plain_provider, request, "model".to_string());
        let forwarded_sender = lowered
            .stream_events
            .expect("explicit direct stream sender must be preserved");
        forwarded_sender.send(LlmStreamEvent::Delta("delta".to_string()));
        assert_eq!(received.lock().expect("stream event lock").len(), 1);

        let streaming_provider = TestProvider::builder()
            .requires_streaming(true)
            .build()
            .into_handle();
        let lowered = build_llm_request(
            &streaming_provider,
            DirectRequest::text("model", "prompt"),
            "model".to_string(),
        );
        assert!(
            lowered.stream_events.is_some(),
            "providers that require streaming need a no-op sender even when direct caller did not request one"
        );
    }
}