1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3 LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4 LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmTerminalReason, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use crate::{LashSchema, SchemaContract};
8use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
9use std::sync::Arc;
10
11#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum DirectRole {
14 System,
15 User,
16 Assistant,
17}
18
19#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
20pub enum DirectPart {
21 Text(String),
22 Image(usize),
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
26pub struct DirectMessage {
27 pub role: DirectRole,
28 pub parts: Vec<DirectPart>,
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
32pub struct DirectJsonSchema {
33 pub name: String,
34 pub schema: SchemaContract,
35 pub strict: bool,
36}
37
38#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
39pub enum DirectOutputSpec {
40 #[default]
41 Text,
42 JsonObject,
43 JsonSchema(DirectJsonSchema),
44}
45
46#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
47pub struct DirectRequest {
48 pub model: String,
49 #[serde(default, skip_serializing_if = "Option::is_none")]
50 pub model_variant: Option<String>,
51 #[serde(default)]
52 pub messages: Vec<DirectMessage>,
53 #[serde(default)]
54 pub attachments: Vec<LlmAttachment>,
55 #[serde(default)]
56 pub output: DirectOutputSpec,
57 #[serde(default)]
58 pub generation: crate::GenerationOptions,
59 #[serde(default, skip)]
60 pub stream_events: Option<LlmEventSender>,
61 #[serde(default, skip_serializing_if = "Option::is_none")]
62 pub session_id: Option<String>,
63 #[serde(default, skip_serializing_if = "Option::is_none")]
64 pub caused_by: Option<crate::CausalRef>,
65 #[serde(default, skip_serializing_if = "Option::is_none")]
66 pub replay: Option<crate::RuntimeReplay>,
67}
68
69impl DirectRequest {
70 pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
71 Self {
72 model: model.into(),
73 model_variant: None,
74 messages: vec![DirectMessage {
75 role: DirectRole::User,
76 parts: vec![DirectPart::Text(prompt.into())],
77 }],
78 attachments: Vec::new(),
79 output: DirectOutputSpec::Text,
80 generation: crate::GenerationOptions::default(),
81 stream_events: None,
82 session_id: None,
83 caused_by: None,
84 replay: None,
85 }
86 }
87
88 pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
89 Self {
90 output: DirectOutputSpec::JsonObject,
91 ..Self::text(model, prompt)
92 }
93 }
94
95 pub fn json_schema(
96 model: impl Into<String>,
97 prompt: impl Into<String>,
98 schema: DirectJsonSchema,
99 ) -> Self {
100 Self {
101 output: DirectOutputSpec::JsonSchema(schema),
102 ..Self::text(model, prompt)
103 }
104 }
105
106 pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
107 self.replay = Some(crate::RuntimeReplay { key: key.into() });
108 self
109 }
110
111 pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
112 self.caused_by = Some(caused_by);
113 self
114 }
115}
116
117#[derive(Debug, thiserror::Error, Clone)]
118pub enum DirectLlmError {
119 #[error("invalid request: {0}")]
120 InvalidRequest(String),
121 #[error("invalid response: {0}")]
122 InvalidResponse(String),
123 #[error("transport error: {0}")]
124 Transport(#[from] LlmTransportError),
125}
126
127pub struct DirectLlmClient {
128 provider: ProviderHandle,
129 trace_sink: Option<Arc<dyn TraceSink>>,
130 trace_context: TraceContext,
131 clock: Arc<dyn crate::Clock>,
132}
133
134impl DirectLlmClient {
135 pub fn new(provider: ProviderHandle) -> Self {
136 Self {
137 provider,
138 trace_sink: None,
139 trace_context: TraceContext::default(),
140 clock: Arc::new(crate::SystemClock),
141 }
142 }
143
144 pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
145 self.trace_sink = sink;
146 self
147 }
148
149 pub fn with_trace_context(mut self, context: TraceContext) -> Self {
150 self.trace_context = context;
151 self
152 }
153
154 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
155 self.clock = clock;
156 self
157 }
158
159 pub fn provider(&self) -> &ProviderHandle {
160 &self.provider
161 }
162
163 pub fn provider_mut(&mut self) -> &mut ProviderHandle {
164 &mut self.provider
165 }
166
167 pub async fn complete(
168 &mut self,
169 request: DirectRequest,
170 ) -> Result<LlmResponse, DirectLlmError> {
171 if let Some(variant) = request.model_variant.as_deref() {
172 self.provider
173 .validate_variant(&request.model, variant)
174 .map_err(DirectLlmError::InvalidRequest)?;
175 }
176
177 let output_for_validation = request.output.clone();
178 let model = request.model.clone();
179 let llm_request = build_llm_request(&self.provider, request, model);
180 let llm_call_id = if self.trace_sink.is_some() {
181 let id = uuid::Uuid::new_v4().to_string();
182 crate::trace::emit_trace(
183 &self.trace_sink,
184 &self.trace_context,
185 TraceContext::default().for_llm_call(id.clone()),
186 TraceEvent::LlmCallStarted {
187 request: crate::trace::trace_llm_request(&llm_request),
188 },
189 self.clock.as_ref(),
190 );
191 Some(id)
192 } else {
193 None
194 };
195 match self.provider.complete(llm_request).await {
196 Ok(response) => {
197 if let Err(error) = validate_direct_output(&output_for_validation, &response) {
198 if let Some(llm_call_id) = llm_call_id {
199 crate::trace::emit_trace(
200 &self.trace_sink,
201 &self.trace_context,
202 TraceContext::default().for_llm_call(llm_call_id),
203 TraceEvent::LlmCallFailed {
204 error: TraceError {
205 message: error.to_string(),
206 retryable: false,
207 terminal_reason: Some(
208 LlmTerminalReason::ProviderError.code().to_string(),
209 ),
210 code: Some("invalid_structured_output".to_string()),
211 raw: None,
212 },
213 stream_summary: None,
214 },
215 self.clock.as_ref(),
216 );
217 }
218 return Err(error);
219 }
220 if let Some(llm_call_id) = llm_call_id {
221 crate::trace::emit_trace(
222 &self.trace_sink,
223 &self.trace_context,
224 TraceContext::default().for_llm_call(llm_call_id),
225 TraceEvent::LlmCallCompleted {
226 response: crate::trace::trace_llm_response(
227 response.full_text.clone(),
228 0,
229 Some(response.terminal_reason),
230 crate::trace::trace_output_parts(&response.parts),
231 ),
232 usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
233 provider_usage: response.provider_usage.clone(),
234 stream_summary: None,
235 },
236 self.clock.as_ref(),
237 );
238 }
239 Ok(response)
240 }
241 Err(error) => {
242 if let Some(llm_call_id) = llm_call_id {
243 crate::trace::emit_trace(
244 &self.trace_sink,
245 &self.trace_context,
246 TraceContext::default().for_llm_call(llm_call_id),
247 TraceEvent::LlmCallFailed {
248 error: TraceError {
249 message: error.message.clone(),
250 retryable: error.retryable,
251 terminal_reason: Some(error.terminal_reason.code().to_string()),
252 code: error.code.clone(),
253 raw: error.raw.clone(),
254 },
255 stream_summary: None,
256 },
257 self.clock.as_ref(),
258 );
259 }
260 Err(DirectLlmError::from(error))
261 }
262 }
263 }
264}
265
266pub(crate) fn build_llm_request(
267 provider: &ProviderHandle,
268 request: DirectRequest,
269 model: String,
270) -> LlmRequest {
271 let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
272 let DirectRequest {
273 model: _,
274 model_variant,
275 messages,
276 attachments,
277 output,
278 generation,
279 stream_events: _,
280 session_id,
281 caused_by: _,
282 replay: _,
283 } = request;
284
285 let output_spec = match output {
286 DirectOutputSpec::Text => None,
287 DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
288 DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
289 name: schema.name,
290 schema: schema.schema,
291 strict: schema.strict,
292 })),
293 };
294
295 let mut llm_messages = Vec::new();
296 for message in messages {
297 let role = match message.role {
298 DirectRole::System => LlmRole::System,
299 DirectRole::User => LlmRole::User,
300 DirectRole::Assistant => LlmRole::Assistant,
301 };
302 let mut blocks: Vec<LlmContentBlock> = Vec::new();
303 for part in message.parts {
304 match part {
305 DirectPart::Text(text) => {
306 if !text.is_empty() {
307 blocks.push(LlmContentBlock::Text {
308 text: text.into(),
309 response_meta: None,
310 cache_breakpoint: false,
311 });
312 }
313 }
314 DirectPart::Image(idx) => {
315 blocks.push(LlmContentBlock::Image {
316 attachment_idx: idx,
317 });
318 }
319 }
320 }
321 if !blocks.is_empty() {
322 llm_messages.push(LlmMessage::new(role, blocks));
323 }
324 }
325
326 LlmRequest {
327 model,
328 messages: llm_messages,
329 attachments,
330 tools: Vec::new().into(),
331 tool_choice: LlmToolChoice::None,
332 model_variant,
333 generation,
334 session_id,
335 output_spec,
336 stream_events,
337 provider_trace: None,
338 }
339}
340
341fn validate_direct_output(
342 output: &DirectOutputSpec,
343 response: &LlmResponse,
344) -> Result<(), DirectLlmError> {
345 let DirectOutputSpec::JsonSchema(schema) = output else {
346 return Ok(());
347 };
348 let parsed: serde_json::Value = serde_json::from_str(response.full_text.trim())
349 .map_err(|err| DirectLlmError::InvalidResponse(format!("expected JSON: {err}")))?;
350 LashSchema::new(schema.schema.canonical().clone())
351 .validate(&parsed)
352 .map_err(DirectLlmError::InvalidResponse)
353}
354
355fn transport_stream_events_for_direct(
356 provider: &ProviderHandle,
357 requested: Option<LlmEventSender>,
358) -> Option<LlmEventSender> {
359 if requested.is_some() {
360 return requested;
361 }
362 if provider.requires_streaming() {
363 Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
364 } else {
365 None
366 }
367}
368
// Unit tests for the `DirectRequest` constructors, `DirectLlmClient` and
// `build_llm_request`, driven by the crate's `TestProvider`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{LlmOutputPart, LlmTerminalReason, LlmUsage};
    use crate::provider::{ProviderOptions, ProviderReliability};
    use crate::testing::TestProvider;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// `DirectRequest::json_schema` must store the given schema, unchanged, as the
    /// request's output spec.
    #[test]
    fn json_schema_request_preserves_output_schema() {
        let schema = DirectJsonSchema {
            name: "answer_shape".to_string(),
            schema: json!({
                "type": "object",
                "properties": {
                    "answer": { "type": "string" }
                },
                "required": ["answer"]
            })
            .into(),
            strict: true,
        };

        let request = DirectRequest::json_schema("model-a", "return json", schema.clone());

        assert_eq!(
            request.output,
            DirectOutputSpec::JsonSchema(schema),
            "DirectRequest::json_schema must carry the requested output schema"
        );
    }

    /// `provider()` exposes the handle passed to `DirectLlmClient::new`. Changes made
    /// through `provider_mut()` are visible through `provider()` afterwards.
    #[test]
    fn direct_client_provider_accessors_expose_owned_provider_handle() {
        let provider = TestProvider::builder()
            .kind("direct-accessor-provider")
            .serialize_config(|| json!({"provider": "owned"}))
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(provider);

        assert_eq!(client.provider().kind(), "direct-accessor-provider");
        assert_eq!(
            client.provider().to_spec().config,
            json!({"provider": "owned"})
        );

        // Mutate options through the mutable accessor, then read them back.
        let mut options = ProviderOptions::default();
        options.reliability = ProviderReliability::default().max_attempts(7);
        options.max_output_tokens = Some(123);
        client.provider_mut().set_options(options.clone());

        assert_eq!(client.provider().options(), options);
    }

    /// End-to-end `complete` with a `json` request. The provider must receive the built
    /// request (model, session id, `JsonObject` output spec, one message), and its
    /// response must be returned unchanged. The response text is not JSON, which also
    /// pins that `JsonObject` output is not validated locally.
    #[tokio::test]
    async fn direct_client_complete_delegates_to_provider_and_returns_response() {
        // Shared slot where the fake provider records the request it was given.
        let captured_request: Arc<Mutex<Option<LlmRequest>>> = Arc::new(Mutex::new(None));
        let captured_for_provider = Arc::clone(&captured_request);
        let provider = TestProvider::builder()
            .kind("direct-complete-provider")
            .complete(move |request| {
                let captured_for_provider = Arc::clone(&captured_for_provider);
                async move {
                    *captured_for_provider.lock().expect("capture lock") = Some(request);
                    Ok(LlmResponse {
                        full_text: "provider delegated response".to_string(),
                        parts: vec![LlmOutputPart::Text {
                            text: "provider delegated response".to_string(),
                            response_meta: None,
                        }],
                        usage: LlmUsage {
                            input_tokens: 11,
                            output_tokens: 3,
                            ..Default::default()
                        },
                        terminal_reason: LlmTerminalReason::Stop,
                        ..Default::default()
                    })
                }
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(provider);
        let mut request = DirectRequest::json("direct-model", "answer as json");
        request.session_id = Some("direct-session".to_string());

        let response = client
            .complete(request)
            .await
            .expect("direct completion should delegate");

        assert_eq!(response.full_text, "provider delegated response");
        let captured = captured_request
            .lock()
            .expect("capture lock")
            .clone()
            .expect("provider should receive a request");
        assert_eq!(captured.model, "direct-model");
        assert_eq!(captured.session_id.as_deref(), Some("direct-session"));
        assert!(matches!(
            captured.output_spec,
            Some(LlmOutputSpec::JsonObject)
        ));
        assert_eq!(captured.messages.len(), 1);
    }

    /// The response parses as JSON but violates the schema (`minItems: 1` on an empty
    /// array). It must surface as `InvalidResponse`. The substring check depends on the
    /// wording of `LashSchema`'s validation message.
    #[tokio::test]
    async fn direct_client_validates_json_schema_output_against_canonical_schema() {
        let provider = TestProvider::builder()
            .kind("direct-validation-provider")
            .complete(|_request| async {
                Ok(LlmResponse {
                    full_text: r#"{"items":[]}"#.to_string(),
                    terminal_reason: LlmTerminalReason::Stop,
                    ..Default::default()
                })
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(provider);
        let request = DirectRequest::json_schema(
            "direct-model",
            "return items",
            DirectJsonSchema {
                name: "items_result".to_string(),
                schema: json!({
                    "type": "object",
                    "required": ["items"],
                    "properties": {
                        "items": {
                            "type": "array",
                            "minItems": 1,
                            "items": { "type": "string" }
                        }
                    }
                })
                .into(),
                strict: true,
            },
        );

        let err = client
            .complete(request)
            .await
            .expect_err("empty items must fail canonical validation");

        assert!(matches!(err, DirectLlmError::InvalidResponse(_)));
        assert!(err.to_string().contains("items >= 1"));
    }

    /// Normalization rules of `build_llm_request`:
    /// - the explicit `model` argument overrides `DirectRequest::model`;
    /// - empty text parts are dropped;
    /// - a message with no remaining blocks (the empty system message) is dropped;
    /// - image parts pass through with their index, even without attachments.
    #[test]
    fn build_llm_request_preserves_nonempty_content_and_drops_empty_messages() {
        let provider = TestProvider::default().into_handle();
        let request = DirectRequest {
            model: "input-model".to_string(),
            messages: vec![
                DirectMessage {
                    role: DirectRole::System,
                    parts: vec![DirectPart::Text(String::new())],
                },
                DirectMessage {
                    role: DirectRole::User,
                    parts: vec![
                        DirectPart::Text("hello".to_string()),
                        DirectPart::Text(String::new()),
                    ],
                },
                DirectMessage {
                    role: DirectRole::Assistant,
                    parts: vec![DirectPart::Image(2)],
                },
            ],
            attachments: Vec::new(),
            output: DirectOutputSpec::Text,
            generation: crate::GenerationOptions::default(),
            stream_events: None,
            session_id: None,
            model_variant: None,
            caused_by: None,
            replay: None,
        };

        let llm_request = build_llm_request(&provider, request, "transport-model".to_string());

        assert_eq!(llm_request.model, "transport-model");
        assert_eq!(
            llm_request.messages.len(),
            2,
            "empty normalized messages must be dropped"
        );
        assert_eq!(llm_request.messages[0].role, LlmRole::User);
        assert_eq!(llm_request.messages[0].blocks.len(), 1);
        assert!(matches!(
            &llm_request.messages[0].blocks[0],
            LlmContentBlock::Text { text, .. } if text.as_ref() == "hello"
        ));
        assert_eq!(llm_request.messages[1].role, LlmRole::Assistant);
        assert!(matches!(
            &llm_request.messages[1].blocks[0],
            LlmContentBlock::Image { attachment_idx: 2 }
        ));
    }

    /// Stream-sender selection:
    /// 1. A caller-supplied sender is forwarded as-is. An event sent through the built
    ///    request reaches the caller's closure.
    /// 2. With no caller sender, a provider that requires streaming still gets one
    ///    (the no-op sender).
    #[test]
    fn build_llm_request_preserves_direct_stream_sender_and_adds_required_noop_sender() {
        let captured_events: Arc<Mutex<Vec<LlmStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_for_sender = Arc::clone(&captured_events);
        let requested_sender = LlmEventSender::new(move |event| {
            captured_for_sender
                .lock()
                .expect("stream event lock")
                .push(event);
        });
        let mut request = DirectRequest::text("model", "prompt");
        request.stream_events = Some(requested_sender);
        let provider = TestProvider::default().into_handle();

        // Case 1: explicit sender.
        let llm_request = build_llm_request(&provider, request, "model".to_string());
        let sender = llm_request
            .stream_events
            .expect("explicit direct stream sender must be preserved");
        sender.send(LlmStreamEvent::Delta("delta".to_string()));
        assert_eq!(captured_events.lock().expect("stream event lock").len(), 1);

        // Case 2: no sender, but the provider requires streaming.
        let streaming_provider = TestProvider::builder()
            .requires_streaming(true)
            .build()
            .into_handle();
        let llm_request = build_llm_request(
            &streaming_provider,
            DirectRequest::text("model", "prompt"),
            "model".to_string(),
        );
        assert!(
            llm_request.stream_events.is_some(),
            "providers that require streaming need a no-op sender even when direct caller did not request one"
        );
    }
}