// lash_core/direct.rs

1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4    LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmTerminalReason, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use crate::{LashSchema, SchemaContract};
8use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
9use std::sync::Arc;
10
11#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
12#[serde(rename_all = "snake_case")]
13pub enum DirectRole {
14    System,
15    User,
16    Assistant,
17}
18
19#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
20pub enum DirectPart {
21    Text(String),
22    Image(usize),
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
26pub struct DirectMessage {
27    pub role: DirectRole,
28    pub parts: Vec<DirectPart>,
29}
30
31#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
32pub struct DirectJsonSchema {
33    pub name: String,
34    pub schema: SchemaContract,
35    pub strict: bool,
36}
37
38#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
39pub enum DirectOutputSpec {
40    #[default]
41    Text,
42    JsonObject,
43    JsonSchema(DirectJsonSchema),
44}
45
46#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
47pub struct DirectRequest {
48    pub model: String,
49    #[serde(default, skip_serializing_if = "Option::is_none")]
50    pub model_variant: Option<String>,
51    #[serde(default)]
52    pub messages: Vec<DirectMessage>,
53    #[serde(default)]
54    pub attachments: Vec<LlmAttachment>,
55    #[serde(default)]
56    pub output: DirectOutputSpec,
57    #[serde(default)]
58    pub generation: crate::GenerationOptions,
59    #[serde(default, skip)]
60    pub stream_events: Option<LlmEventSender>,
61    #[serde(default, skip_serializing_if = "Option::is_none")]
62    pub session_id: Option<String>,
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub caused_by: Option<crate::CausalRef>,
65    #[serde(default, skip_serializing_if = "Option::is_none")]
66    pub replay: Option<crate::RuntimeReplay>,
67}
68
69impl DirectRequest {
70    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
71        Self {
72            model: model.into(),
73            model_variant: None,
74            messages: vec![DirectMessage {
75                role: DirectRole::User,
76                parts: vec![DirectPart::Text(prompt.into())],
77            }],
78            attachments: Vec::new(),
79            output: DirectOutputSpec::Text,
80            generation: crate::GenerationOptions::default(),
81            stream_events: None,
82            session_id: None,
83            caused_by: None,
84            replay: None,
85        }
86    }
87
88    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
89        Self {
90            output: DirectOutputSpec::JsonObject,
91            ..Self::text(model, prompt)
92        }
93    }
94
95    pub fn json_schema(
96        model: impl Into<String>,
97        prompt: impl Into<String>,
98        schema: DirectJsonSchema,
99    ) -> Self {
100        Self {
101            output: DirectOutputSpec::JsonSchema(schema),
102            ..Self::text(model, prompt)
103        }
104    }
105
106    pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
107        self.replay = Some(crate::RuntimeReplay { key: key.into() });
108        self
109    }
110
111    pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
112        self.caused_by = Some(caused_by);
113        self
114    }
115}
116
117#[derive(Debug, thiserror::Error, Clone)]
118pub enum DirectLlmError {
119    #[error("invalid request: {0}")]
120    InvalidRequest(String),
121    #[error("invalid response: {0}")]
122    InvalidResponse(String),
123    #[error("transport error: {0}")]
124    Transport(#[from] LlmTransportError),
125}
126
127pub struct DirectLlmClient {
128    provider: ProviderHandle,
129    trace_sink: Option<Arc<dyn TraceSink>>,
130    trace_context: TraceContext,
131    clock: Arc<dyn crate::Clock>,
132}
133
134impl DirectLlmClient {
135    pub fn new(provider: ProviderHandle) -> Self {
136        Self {
137            provider,
138            trace_sink: None,
139            trace_context: TraceContext::default(),
140            clock: Arc::new(crate::SystemClock),
141        }
142    }
143
144    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
145        self.trace_sink = sink;
146        self
147    }
148
149    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
150        self.trace_context = context;
151        self
152    }
153
154    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
155        self.clock = clock;
156        self
157    }
158
159    pub fn provider(&self) -> &ProviderHandle {
160        &self.provider
161    }
162
163    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
164        &mut self.provider
165    }
166
167    pub async fn complete(
168        &mut self,
169        request: DirectRequest,
170    ) -> Result<LlmResponse, DirectLlmError> {
171        if let Some(variant) = request.model_variant.as_deref() {
172            self.provider
173                .validate_variant(&request.model, variant)
174                .map_err(DirectLlmError::InvalidRequest)?;
175        }
176
177        let output_for_validation = request.output.clone();
178        let model = request.model.clone();
179        let llm_request = build_llm_request(&self.provider, request, model);
180        let llm_call_id = if self.trace_sink.is_some() {
181            let id = uuid::Uuid::new_v4().to_string();
182            crate::trace::emit_trace(
183                &self.trace_sink,
184                &self.trace_context,
185                TraceContext::default().for_llm_call(id.clone()),
186                TraceEvent::LlmCallStarted {
187                    request: crate::trace::trace_llm_request(&llm_request),
188                },
189                self.clock.as_ref(),
190            );
191            Some(id)
192        } else {
193            None
194        };
195        match self.provider.complete(llm_request).await {
196            Ok(response) => {
197                if let Err(error) = validate_direct_output(&output_for_validation, &response) {
198                    if let Some(llm_call_id) = llm_call_id {
199                        crate::trace::emit_trace(
200                            &self.trace_sink,
201                            &self.trace_context,
202                            TraceContext::default().for_llm_call(llm_call_id),
203                            TraceEvent::LlmCallFailed {
204                                error: TraceError {
205                                    message: error.to_string(),
206                                    retryable: false,
207                                    terminal_reason: Some(
208                                        LlmTerminalReason::ProviderError.code().to_string(),
209                                    ),
210                                    code: Some("invalid_structured_output".to_string()),
211                                    raw: None,
212                                },
213                                stream_summary: None,
214                            },
215                            self.clock.as_ref(),
216                        );
217                    }
218                    return Err(error);
219                }
220                if let Some(llm_call_id) = llm_call_id {
221                    crate::trace::emit_trace(
222                        &self.trace_sink,
223                        &self.trace_context,
224                        TraceContext::default().for_llm_call(llm_call_id),
225                        TraceEvent::LlmCallCompleted {
226                            response: crate::trace::trace_llm_response(
227                                response.full_text.clone(),
228                                0,
229                                Some(response.terminal_reason),
230                                crate::trace::trace_output_parts(&response.parts),
231                            ),
232                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
233                            provider_usage: response.provider_usage.clone(),
234                            stream_summary: None,
235                        },
236                        self.clock.as_ref(),
237                    );
238                }
239                Ok(response)
240            }
241            Err(error) => {
242                if let Some(llm_call_id) = llm_call_id {
243                    crate::trace::emit_trace(
244                        &self.trace_sink,
245                        &self.trace_context,
246                        TraceContext::default().for_llm_call(llm_call_id),
247                        TraceEvent::LlmCallFailed {
248                            error: TraceError {
249                                message: error.message.clone(),
250                                retryable: error.retryable,
251                                terminal_reason: Some(error.terminal_reason.code().to_string()),
252                                code: error.code.clone(),
253                                raw: error.raw.clone(),
254                            },
255                            stream_summary: None,
256                        },
257                        self.clock.as_ref(),
258                    );
259                }
260                Err(DirectLlmError::from(error))
261            }
262        }
263    }
264}
265
266pub(crate) fn build_llm_request(
267    provider: &ProviderHandle,
268    request: DirectRequest,
269    model: String,
270) -> LlmRequest {
271    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
272    let DirectRequest {
273        model: _,
274        model_variant,
275        messages,
276        attachments,
277        output,
278        generation,
279        stream_events: _,
280        session_id,
281        caused_by: _,
282        replay: _,
283    } = request;
284
285    let output_spec = match output {
286        DirectOutputSpec::Text => None,
287        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
288        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
289            name: schema.name,
290            schema: schema.schema,
291            strict: schema.strict,
292        })),
293    };
294
295    let mut llm_messages = Vec::new();
296    for message in messages {
297        let role = match message.role {
298            DirectRole::System => LlmRole::System,
299            DirectRole::User => LlmRole::User,
300            DirectRole::Assistant => LlmRole::Assistant,
301        };
302        let mut blocks: Vec<LlmContentBlock> = Vec::new();
303        for part in message.parts {
304            match part {
305                DirectPart::Text(text) => {
306                    if !text.is_empty() {
307                        blocks.push(LlmContentBlock::Text {
308                            text: text.into(),
309                            response_meta: None,
310                            cache_breakpoint: false,
311                        });
312                    }
313                }
314                DirectPart::Image(idx) => {
315                    blocks.push(LlmContentBlock::Image {
316                        attachment_idx: idx,
317                    });
318                }
319            }
320        }
321        if !blocks.is_empty() {
322            llm_messages.push(LlmMessage::new(role, blocks));
323        }
324    }
325
326    LlmRequest {
327        model,
328        messages: llm_messages,
329        attachments,
330        tools: Vec::new().into(),
331        tool_choice: LlmToolChoice::None,
332        model_variant,
333        generation,
334        session_id,
335        output_spec,
336        stream_events,
337        provider_trace: None,
338    }
339}
340
341fn validate_direct_output(
342    output: &DirectOutputSpec,
343    response: &LlmResponse,
344) -> Result<(), DirectLlmError> {
345    let DirectOutputSpec::JsonSchema(schema) = output else {
346        return Ok(());
347    };
348    let parsed: serde_json::Value = serde_json::from_str(response.full_text.trim())
349        .map_err(|err| DirectLlmError::InvalidResponse(format!("expected JSON: {err}")))?;
350    LashSchema::new(schema.schema.canonical().clone())
351        .validate(&parsed)
352        .map_err(DirectLlmError::InvalidResponse)
353}
354
355fn transport_stream_events_for_direct(
356    provider: &ProviderHandle,
357    requested: Option<LlmEventSender>,
358) -> Option<LlmEventSender> {
359    if requested.is_some() {
360        return requested;
361    }
362    if provider.requires_streaming() {
363        Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
364    } else {
365        None
366    }
367}
368
#[cfg(test)]
/// Unit tests for request construction, provider delegation, structured-output
/// validation, and transport stream-sender selection.
mod tests {
    use super::*;
    use crate::llm::types::{LlmOutputPart, LlmTerminalReason, LlmUsage};
    use crate::provider::{ProviderOptions, ProviderReliability};
    use crate::testing::TestProvider;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// `json_schema` must store exactly the schema it was handed.
    #[test]
    fn json_schema_request_preserves_output_schema() {
        let shape = json!({
            "type": "object",
            "properties": {
                "answer": { "type": "string" }
            },
            "required": ["answer"]
        });
        let schema = DirectJsonSchema {
            name: "answer_shape".to_string(),
            schema: shape.into(),
            strict: true,
        };
        let expected = DirectOutputSpec::JsonSchema(schema.clone());

        let request = DirectRequest::json_schema("model-a", "return json", schema);

        assert_eq!(
            request.output, expected,
            "DirectRequest::json_schema must carry the requested output schema"
        );
    }

    /// The accessors hand out the very provider the client was built with.
    #[test]
    fn direct_client_provider_accessors_expose_owned_provider_handle() {
        let handle = TestProvider::builder()
            .kind("direct-accessor-provider")
            .serialize_config(|| json!({"provider": "owned"}))
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);

        let provider = client.provider();
        assert_eq!(provider.kind(), "direct-accessor-provider");
        assert_eq!(provider.to_spec().config, json!({"provider": "owned"}));

        let mut options = ProviderOptions::default();
        options.max_output_tokens = Some(123);
        options.reliability = ProviderReliability::default().max_attempts(7);
        client.provider_mut().set_options(options.clone());

        assert_eq!(client.provider().options(), options);
    }

    /// `complete` forwards the converted request and returns the provider reply untouched.
    #[tokio::test]
    async fn direct_client_complete_delegates_to_provider_and_returns_response() {
        const REPLY: &str = "provider delegated response";

        let seen: Arc<Mutex<Option<LlmRequest>>> = Arc::new(Mutex::new(None));
        let seen_by_provider = Arc::clone(&seen);
        let provider = TestProvider::builder()
            .kind("direct-complete-provider")
            .complete(move |request| {
                let slot = Arc::clone(&seen_by_provider);
                async move {
                    *slot.lock().expect("capture lock") = Some(request);
                    Ok(LlmResponse {
                        full_text: REPLY.to_string(),
                        parts: vec![LlmOutputPart::Text {
                            text: REPLY.to_string(),
                            response_meta: None,
                        }],
                        usage: LlmUsage {
                            input_tokens: 11,
                            output_tokens: 3,
                            ..Default::default()
                        },
                        terminal_reason: LlmTerminalReason::Stop,
                        ..Default::default()
                    })
                }
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(provider);
        let request = DirectRequest {
            session_id: Some("direct-session".to_string()),
            ..DirectRequest::json("direct-model", "answer as json")
        };

        let response = client
            .complete(request)
            .await
            .expect("direct completion should delegate");
        assert_eq!(response.full_text, REPLY);

        let forwarded = seen
            .lock()
            .expect("capture lock")
            .clone()
            .expect("provider should receive a request");
        assert_eq!(forwarded.model, "direct-model");
        assert_eq!(forwarded.session_id.as_deref(), Some("direct-session"));
        assert!(matches!(
            forwarded.output_spec,
            Some(LlmOutputSpec::JsonObject)
        ));
        assert_eq!(forwarded.messages.len(), 1);
    }

    /// A reply that parses but violates `minItems` must be rejected as an invalid response.
    #[tokio::test]
    async fn direct_client_validates_json_schema_output_against_canonical_schema() {
        let provider = TestProvider::builder()
            .kind("direct-validation-provider")
            .complete(|_request| async {
                Ok(LlmResponse {
                    full_text: r#"{"items":[]}"#.to_string(),
                    terminal_reason: LlmTerminalReason::Stop,
                    ..Default::default()
                })
            })
            .build()
            .into_handle();
        let schema = DirectJsonSchema {
            name: "items_result".to_string(),
            schema: json!({
                "type": "object",
                "required": ["items"],
                "properties": {
                    "items": {
                        "type": "array",
                        "minItems": 1,
                        "items": { "type": "string" }
                    }
                }
            })
            .into(),
            strict: true,
        };
        let request = DirectRequest::json_schema("direct-model", "return items", schema);
        let mut client = DirectLlmClient::new(provider);

        let err = client
            .complete(request)
            .await
            .expect_err("empty items must fail canonical validation");

        let rendered = err.to_string();
        assert!(matches!(err, DirectLlmError::InvalidResponse(_)));
        assert!(rendered.contains("items >= 1"));
    }

    /// Empty text parts vanish, empty messages vanish, everything else is kept in order.
    #[test]
    fn build_llm_request_preserves_nonempty_content_and_drops_empty_messages() {
        let provider = TestProvider::default().into_handle();
        let turn = |role: DirectRole, parts: Vec<DirectPart>| DirectMessage { role, parts };
        let request = DirectRequest {
            model: "input-model".to_string(),
            model_variant: None,
            messages: vec![
                turn(DirectRole::System, vec![DirectPart::Text(String::new())]),
                turn(
                    DirectRole::User,
                    vec![
                        DirectPart::Text("hello".to_string()),
                        DirectPart::Text(String::new()),
                    ],
                ),
                turn(DirectRole::Assistant, vec![DirectPart::Image(2)]),
            ],
            attachments: Vec::new(),
            output: DirectOutputSpec::Text,
            generation: crate::GenerationOptions::default(),
            stream_events: None,
            session_id: None,
            caused_by: None,
            replay: None,
        };

        let llm_request = build_llm_request(&provider, request, "transport-model".to_string());

        assert_eq!(llm_request.model, "transport-model");
        assert_eq!(
            llm_request.messages.len(),
            2,
            "empty normalized messages must be dropped"
        );

        let user = &llm_request.messages[0];
        assert_eq!(user.role, LlmRole::User);
        assert_eq!(user.blocks.len(), 1);
        assert!(matches!(
            &user.blocks[0],
            LlmContentBlock::Text { text, .. } if text.as_ref() == "hello"
        ));

        let assistant = &llm_request.messages[1];
        assert_eq!(assistant.role, LlmRole::Assistant);
        assert!(matches!(
            &assistant.blocks[0],
            LlmContentBlock::Image { attachment_idx: 2 }
        ));
    }

    /// Caller senders are passed through; streaming-only providers get a no-op sender.
    #[test]
    fn build_llm_request_preserves_direct_stream_sender_and_adds_required_noop_sender() {
        let received: Arc<Mutex<Vec<LlmStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);
        let request = DirectRequest {
            stream_events: Some(LlmEventSender::new(move |event| {
                sink.lock().expect("stream event lock").push(event);
            })),
            ..DirectRequest::text("model", "prompt")
        };
        let plain_provider = TestProvider::default().into_handle();

        let forwarded = build_llm_request(&plain_provider, request, "model".to_string())
            .stream_events
            .expect("explicit direct stream sender must be preserved");
        forwarded.send(LlmStreamEvent::Delta("delta".to_string()));
        assert_eq!(received.lock().expect("stream event lock").len(), 1);

        let streaming_provider = TestProvider::builder()
            .requires_streaming(true)
            .build()
            .into_handle();
        let streamed = build_llm_request(
            &streaming_provider,
            DirectRequest::text("model", "prompt"),
            "model".to_string(),
        );
        assert!(
            streamed.stream_events.is_some(),
            "providers that require streaming need a no-op sender even when direct caller did not request one"
        );
    }
}