Skip to main content

lash_core/
direct.rs

1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4    LlmRequest, LlmRequestScope, LlmResponse, LlmRole, LlmStreamEvent, LlmTerminalReason,
5    LlmToolChoice,
6};
7use crate::provider::ProviderHandle;
8use crate::{LashSchema, SchemaContract};
9use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
10use std::sync::Arc;
11
12#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
13#[serde(rename_all = "snake_case")]
14pub enum DirectRole {
15    System,
16    User,
17    Assistant,
18}
19
20#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
21pub enum DirectPart {
22    Text(String),
23    Image(usize),
24}
25
26#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
27pub struct DirectMessage {
28    pub role: DirectRole,
29    pub parts: Vec<DirectPart>,
30}
31
32#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
33pub struct DirectJsonSchema {
34    pub name: String,
35    pub schema: SchemaContract,
36    pub strict: bool,
37}
38
39#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
40pub enum DirectOutputSpec {
41    #[default]
42    Text,
43    JsonObject,
44    JsonSchema(DirectJsonSchema),
45}
46
47#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
48pub struct DirectRequest {
49    pub model: String,
50    #[serde(default, skip_serializing_if = "Option::is_none")]
51    pub model_variant: Option<String>,
52    #[serde(default)]
53    pub messages: Vec<DirectMessage>,
54    #[serde(default)]
55    pub attachments: Vec<LlmAttachment>,
56    #[serde(default)]
57    pub output: DirectOutputSpec,
58    #[serde(default)]
59    pub generation: crate::GenerationOptions,
60    #[serde(default, skip)]
61    pub stream_events: Option<LlmEventSender>,
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub session_id: Option<String>,
64    #[serde(default, skip_serializing_if = "Option::is_none")]
65    pub caused_by: Option<crate::CausalRef>,
66    #[serde(default, skip_serializing_if = "Option::is_none")]
67    pub replay: Option<crate::RuntimeReplay>,
68}
69
70impl DirectRequest {
71    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
72        Self {
73            model: model.into(),
74            model_variant: None,
75            messages: vec![DirectMessage {
76                role: DirectRole::User,
77                parts: vec![DirectPart::Text(prompt.into())],
78            }],
79            attachments: Vec::new(),
80            output: DirectOutputSpec::Text,
81            generation: crate::GenerationOptions::default(),
82            stream_events: None,
83            session_id: None,
84            caused_by: None,
85            replay: None,
86        }
87    }
88
89    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
90        Self {
91            output: DirectOutputSpec::JsonObject,
92            ..Self::text(model, prompt)
93        }
94    }
95
96    pub fn json_schema(
97        model: impl Into<String>,
98        prompt: impl Into<String>,
99        schema: DirectJsonSchema,
100    ) -> Self {
101        Self {
102            output: DirectOutputSpec::JsonSchema(schema),
103            ..Self::text(model, prompt)
104        }
105    }
106
107    pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
108        self.replay = Some(crate::RuntimeReplay { key: key.into() });
109        self
110    }
111
112    pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
113        self.caused_by = Some(caused_by);
114        self
115    }
116}
117
118#[derive(Debug, thiserror::Error, Clone)]
119pub enum DirectLlmError {
120    #[error("invalid request: {0}")]
121    InvalidRequest(String),
122    #[error("invalid response: {0}")]
123    InvalidResponse(String),
124    #[error("transport error: {0}")]
125    Transport(#[from] Box<LlmTransportError>),
126}
127
128pub struct DirectLlmClient {
129    provider: ProviderHandle,
130    trace_sink: Option<Arc<dyn TraceSink>>,
131    trace_context: TraceContext,
132    clock: Arc<dyn crate::Clock>,
133}
134
135impl DirectLlmClient {
136    pub fn new(provider: ProviderHandle) -> Self {
137        Self {
138            provider,
139            trace_sink: None,
140            trace_context: TraceContext::default(),
141            clock: Arc::new(crate::SystemClock),
142        }
143    }
144
145    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
146        self.trace_sink = sink;
147        self
148    }
149
150    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
151        self.trace_context = context;
152        self
153    }
154
155    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
156        self.clock = clock;
157        self
158    }
159
160    pub fn provider(&self) -> &ProviderHandle {
161        &self.provider
162    }
163
164    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
165        &mut self.provider
166    }
167
168    pub async fn complete(
169        &mut self,
170        request: DirectRequest,
171    ) -> Result<LlmResponse, DirectLlmError> {
172        if let Some(variant) = request.model_variant.as_deref() {
173            self.provider
174                .validate_variant(&request.model, variant)
175                .map_err(DirectLlmError::InvalidRequest)?;
176        }
177
178        let output_for_validation = request.output.clone();
179        let model = request.model.clone();
180        let llm_request = build_llm_request(&self.provider, request, model);
181        let llm_call_id = if self.trace_sink.is_some() {
182            let id = uuid::Uuid::new_v4().to_string();
183            crate::trace::emit_trace(
184                &self.trace_sink,
185                &self.trace_context,
186                TraceContext::default().for_llm_call(id.clone()),
187                TraceEvent::LlmCallStarted {
188                    request: crate::trace::trace_llm_request(&llm_request),
189                },
190                self.clock.as_ref(),
191            );
192            Some(id)
193        } else {
194            None
195        };
196        match self.provider.complete(llm_request).await {
197            Ok(response) => {
198                if let Err(error) = validate_direct_output(&output_for_validation, &response) {
199                    if let Some(llm_call_id) = llm_call_id {
200                        crate::trace::emit_trace(
201                            &self.trace_sink,
202                            &self.trace_context,
203                            TraceContext::default().for_llm_call(llm_call_id),
204                            TraceEvent::LlmCallFailed {
205                                error: TraceError {
206                                    message: error.to_string(),
207                                    retryable: false,
208                                    terminal_reason: Some(
209                                        LlmTerminalReason::ProviderError.code().to_string(),
210                                    ),
211                                    code: Some("invalid_structured_output".to_string()),
212                                    raw: None,
213                                },
214                                stream_summary: None,
215                            },
216                            self.clock.as_ref(),
217                        );
218                    }
219                    return Err(error);
220                }
221                if let Some(llm_call_id) = llm_call_id {
222                    crate::trace::emit_trace(
223                        &self.trace_sink,
224                        &self.trace_context,
225                        TraceContext::default().for_llm_call(llm_call_id),
226                        TraceEvent::LlmCallCompleted {
227                            response: crate::trace::trace_llm_response(
228                                response.full_text.clone(),
229                                0,
230                                Some(response.terminal_reason),
231                                crate::trace::trace_output_parts(&response.parts),
232                            ),
233                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
234                            provider_usage: response.provider_usage.clone(),
235                            stream_summary: None,
236                        },
237                        self.clock.as_ref(),
238                    );
239                }
240                Ok(response)
241            }
242            Err(error) => {
243                if let Some(llm_call_id) = llm_call_id {
244                    crate::trace::emit_trace(
245                        &self.trace_sink,
246                        &self.trace_context,
247                        TraceContext::default().for_llm_call(llm_call_id),
248                        TraceEvent::LlmCallFailed {
249                            error: TraceError {
250                                message: error.message.clone(),
251                                retryable: error.retryable,
252                                terminal_reason: Some(error.terminal_reason.code().to_string()),
253                                code: error.code.clone(),
254                                raw: error.raw.clone(),
255                            },
256                            stream_summary: None,
257                        },
258                        self.clock.as_ref(),
259                    );
260                }
261                Err(DirectLlmError::from(Box::new(error)))
262            }
263        }
264    }
265}
266
267pub(crate) fn build_llm_request(
268    provider: &ProviderHandle,
269    request: DirectRequest,
270    model: String,
271) -> LlmRequest {
272    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
273    let DirectRequest {
274        model: _,
275        model_variant,
276        messages,
277        attachments,
278        output,
279        generation,
280        stream_events: _,
281        session_id,
282        caused_by: _,
283        replay: _,
284    } = request;
285
286    let output_spec = match output {
287        DirectOutputSpec::Text => None,
288        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
289        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
290            name: schema.name,
291            schema: schema.schema,
292            strict: schema.strict,
293        })),
294    };
295
296    let mut llm_messages = Vec::new();
297    for message in messages {
298        let role = match message.role {
299            DirectRole::System => LlmRole::System,
300            DirectRole::User => LlmRole::User,
301            DirectRole::Assistant => LlmRole::Assistant,
302        };
303        let mut blocks: Vec<LlmContentBlock> = Vec::new();
304        for part in message.parts {
305            match part {
306                DirectPart::Text(text) => {
307                    if !text.is_empty() {
308                        blocks.push(LlmContentBlock::Text {
309                            text: text.into(),
310                            response_meta: None,
311                            cache_breakpoint: false,
312                        });
313                    }
314                }
315                DirectPart::Image(idx) => {
316                    blocks.push(LlmContentBlock::Image {
317                        attachment_idx: idx,
318                    });
319                }
320            }
321        }
322        if !blocks.is_empty() {
323            llm_messages.push(LlmMessage::new(role, blocks));
324        }
325    }
326
327    let scope = match session_id {
328        Some(session_id) => LlmRequestScope::new(
329            session_id.clone(),
330            format!("{session_id}:frame:direct"),
331            format!("{session_id}:direct"),
332        ),
333        None => {
334            let request_id = uuid::Uuid::new_v4().to_string();
335            LlmRequestScope::new(
336                format!("direct:{request_id}"),
337                format!("direct:{request_id}:frame"),
338                request_id,
339            )
340        }
341    };
342
343    LlmRequest {
344        model,
345        messages: llm_messages,
346        attachments,
347        tools: Vec::new().into(),
348        tool_choice: LlmToolChoice::None,
349        model_variant,
350        generation,
351        scope,
352        output_spec,
353        stream_events,
354        provider_trace: None,
355    }
356}
357
358fn validate_direct_output(
359    output: &DirectOutputSpec,
360    response: &LlmResponse,
361) -> Result<(), DirectLlmError> {
362    let DirectOutputSpec::JsonSchema(schema) = output else {
363        return Ok(());
364    };
365    let parsed: serde_json::Value = serde_json::from_str(response.full_text.trim())
366        .map_err(|err| DirectLlmError::InvalidResponse(format!("expected JSON: {err}")))?;
367    LashSchema::new(schema.schema.canonical().clone())
368        .validate(&parsed)
369        .map_err(DirectLlmError::InvalidResponse)
370}
371
372fn transport_stream_events_for_direct(
373    provider: &ProviderHandle,
374    requested: Option<LlmEventSender>,
375) -> Option<LlmEventSender> {
376    if requested.is_some() {
377        return requested;
378    }
379    if provider.requires_streaming() {
380        Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
381    } else {
382        None
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{LlmOutputPart, LlmTerminalReason, LlmUsage};
    use crate::provider::{ProviderOptions, ProviderReliability};
    use crate::testing::TestProvider;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Shorthand for building a message fixture.
    fn message(role: DirectRole, parts: Vec<DirectPart>) -> DirectMessage {
        DirectMessage { role, parts }
    }

    /// Schema fixture requiring a non-empty array of strings under `items`.
    fn non_empty_items_schema() -> DirectJsonSchema {
        DirectJsonSchema {
            name: "items_result".to_string(),
            schema: json!({
                "type": "object",
                "required": ["items"],
                "properties": {
                    "items": {
                        "type": "array",
                        "minItems": 1,
                        "items": { "type": "string" }
                    }
                }
            })
            .into(),
            strict: true,
        }
    }

    // The `json_schema` constructor must store exactly the schema it is given.
    #[test]
    fn json_schema_request_preserves_output_schema() {
        let expected = DirectJsonSchema {
            name: "answer_shape".to_string(),
            schema: json!({
                "type": "object",
                "properties": {
                    "answer": { "type": "string" }
                },
                "required": ["answer"]
            })
            .into(),
            strict: true,
        };

        let built = DirectRequest::json_schema("model-a", "return json", expected.clone());

        assert_eq!(
            built.output,
            DirectOutputSpec::JsonSchema(expected),
            "DirectRequest::json_schema must carry the requested output schema"
        );
    }

    // `provider()` / `provider_mut()` must expose the handle the client owns.
    #[test]
    fn direct_client_provider_accessors_expose_owned_provider_handle() {
        let handle = TestProvider::builder()
            .kind("direct-accessor-provider")
            .serialize_config(|| json!({"provider": "owned"}))
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);

        assert_eq!(client.provider().kind(), "direct-accessor-provider");
        assert_eq!(
            client.provider().to_spec().config,
            json!({"provider": "owned"})
        );

        let new_options = ProviderOptions {
            reliability: ProviderReliability::default().max_attempts(7),
            max_output_tokens: Some(123),
            ..Default::default()
        };
        client.provider_mut().set_options(new_options.clone());

        assert_eq!(client.provider().options(), new_options);
    }

    // `complete` must forward a correctly scoped request and hand back the
    // provider's response untouched.
    #[tokio::test]
    async fn direct_client_complete_delegates_to_provider_and_returns_response() {
        let seen_request: Arc<Mutex<Option<LlmRequest>>> = Arc::new(Mutex::new(None));
        let seen_by_provider = Arc::clone(&seen_request);
        let handle = TestProvider::builder()
            .kind("direct-complete-provider")
            .complete(move |request| {
                let slot = Arc::clone(&seen_by_provider);
                async move {
                    *slot.lock().expect("capture lock") = Some(request);
                    Ok(LlmResponse {
                        full_text: "provider delegated response".to_string(),
                        parts: vec![LlmOutputPart::Text {
                            text: "provider delegated response".to_string(),
                            response_meta: None,
                        }],
                        usage: LlmUsage {
                            input_tokens: 11,
                            output_tokens: 3,
                            ..Default::default()
                        },
                        terminal_reason: LlmTerminalReason::Stop,
                        ..Default::default()
                    })
                }
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);
        let mut request = DirectRequest::json("direct-model", "answer as json");
        request.session_id = Some("direct-session".to_string());

        let response = client
            .complete(request)
            .await
            .expect("direct completion should delegate");

        assert_eq!(response.full_text, "provider delegated response");
        let forwarded = seen_request
            .lock()
            .expect("capture lock")
            .clone()
            .expect("provider should receive a request");
        assert_eq!(forwarded.model, "direct-model");
        assert_eq!(forwarded.scope.session_id, "direct-session");
        assert_eq!(
            forwarded.scope.agent_frame_id,
            "direct-session:frame:direct"
        );
        assert_eq!(forwarded.scope.request_id, "direct-session:direct");
        assert!(matches!(
            forwarded.output_spec,
            Some(LlmOutputSpec::JsonObject)
        ));
        assert_eq!(forwarded.messages.len(), 1);
    }

    // A response violating the requested schema must surface as
    // `InvalidResponse`.
    #[tokio::test]
    async fn direct_client_validates_json_schema_output_against_canonical_schema() {
        let handle = TestProvider::builder()
            .kind("direct-validation-provider")
            .complete(|_request| async {
                Ok(LlmResponse {
                    full_text: r#"{"items":[]}"#.to_string(),
                    terminal_reason: LlmTerminalReason::Stop,
                    ..Default::default()
                })
            })
            .build()
            .into_handle();
        let mut client = DirectLlmClient::new(handle);
        let request =
            DirectRequest::json_schema("direct-model", "return items", non_empty_items_schema());

        let failure = client
            .complete(request)
            .await
            .expect_err("empty items must fail canonical validation");

        assert!(matches!(failure, DirectLlmError::InvalidResponse(_)));
        assert!(failure.to_string().contains("items >= 1"));
    }

    // Empty text parts are skipped, messages left empty are dropped, and the
    // explicit `model` argument wins over the request's own model.
    #[test]
    fn build_llm_request_preserves_nonempty_content_and_drops_empty_messages() {
        let handle = TestProvider::default().into_handle();
        let request = DirectRequest {
            model: "input-model".to_string(),
            model_variant: None,
            messages: vec![
                message(DirectRole::System, vec![DirectPart::Text(String::new())]),
                message(
                    DirectRole::User,
                    vec![
                        DirectPart::Text("hello".to_string()),
                        DirectPart::Text(String::new()),
                    ],
                ),
                message(DirectRole::Assistant, vec![DirectPart::Image(2)]),
            ],
            attachments: Vec::new(),
            output: DirectOutputSpec::Text,
            generation: crate::GenerationOptions::default(),
            stream_events: None,
            session_id: None,
            caused_by: None,
            replay: None,
        };

        let built = build_llm_request(&handle, request, "transport-model".to_string());

        assert_eq!(built.model, "transport-model");
        assert_eq!(
            built.messages.len(),
            2,
            "empty normalized messages must be dropped"
        );
        assert_eq!(built.messages[0].role, LlmRole::User);
        assert_eq!(built.messages[0].blocks.len(), 1);
        assert!(matches!(
            &built.messages[0].blocks[0],
            LlmContentBlock::Text { text, .. } if text.as_ref() == "hello"
        ));
        assert_eq!(built.messages[1].role, LlmRole::Assistant);
        assert!(matches!(
            &built.messages[1].blocks[0],
            LlmContentBlock::Image { attachment_idx: 2 }
        ));
    }

    // A caller-supplied sender is kept; a streaming-only provider gets a
    // sender even when the caller supplied none.
    #[test]
    fn build_llm_request_preserves_direct_stream_sender_and_adds_required_noop_sender() {
        let received: Arc<Mutex<Vec<LlmStreamEvent>>> = Arc::new(Mutex::new(Vec::new()));
        let received_by_sender = Arc::clone(&received);
        let caller_sender = LlmEventSender::new(move |event| {
            received_by_sender
                .lock()
                .expect("stream event lock")
                .push(event);
        });
        let mut request = DirectRequest::text("model", "prompt");
        request.stream_events = Some(caller_sender);
        let plain_handle = TestProvider::default().into_handle();

        let built = build_llm_request(&plain_handle, request, "model".to_string());
        let forwarded_sender = built
            .stream_events
            .expect("explicit direct stream sender must be preserved");
        forwarded_sender.send(LlmStreamEvent::Delta("delta".to_string()));
        assert_eq!(received.lock().expect("stream event lock").len(), 1);

        let streaming_handle = TestProvider::builder()
            .requires_streaming(true)
            .build()
            .into_handle();
        let built = build_llm_request(
            &streaming_handle,
            DirectRequest::text("model", "prompt"),
            "model".to_string(),
        );
        assert!(
            built.stream_events.is_some(),
            "providers that require streaming need a no-op sender even when direct caller did not request one"
        );
    }
}
630}