// lash_core/direct.rs

1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4    LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
10#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum DirectRole {
13    System,
14    User,
15    Assistant,
16}
17
18#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
19pub enum DirectPart {
20    Text(String),
21    Image(usize),
22}
23
24#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub struct DirectMessage {
26    pub role: DirectRole,
27    pub parts: Vec<DirectPart>,
28}
29
30#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub struct DirectJsonSchema {
32    pub name: String,
33    pub schema: serde_json::Value,
34    pub strict: bool,
35}
36
37#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
38pub enum DirectOutputSpec {
39    #[default]
40    Text,
41    JsonObject,
42    JsonSchema(DirectJsonSchema),
43}
44
45#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
46pub struct DirectRequest {
47    pub model: String,
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub model_variant: Option<String>,
50    #[serde(default)]
51    pub messages: Vec<DirectMessage>,
52    #[serde(default)]
53    pub attachments: Vec<LlmAttachment>,
54    #[serde(default)]
55    pub output: DirectOutputSpec,
56    #[serde(default)]
57    pub generation: crate::GenerationOptions,
58    #[serde(default, skip)]
59    pub stream_events: Option<LlmEventSender>,
60    #[serde(default, skip_serializing_if = "Option::is_none")]
61    pub session_id: Option<String>,
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub caused_by: Option<crate::CausalRef>,
64    #[serde(default, skip_serializing_if = "Option::is_none")]
65    pub replay: Option<crate::RuntimeReplay>,
66}
67
68impl DirectRequest {
69    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
70        Self {
71            model: model.into(),
72            model_variant: None,
73            messages: vec![DirectMessage {
74                role: DirectRole::User,
75                parts: vec![DirectPart::Text(prompt.into())],
76            }],
77            attachments: Vec::new(),
78            output: DirectOutputSpec::Text,
79            generation: crate::GenerationOptions::default(),
80            stream_events: None,
81            session_id: None,
82            caused_by: None,
83            replay: None,
84        }
85    }
86
87    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
88        Self {
89            output: DirectOutputSpec::JsonObject,
90            ..Self::text(model, prompt)
91        }
92    }
93
94    pub fn json_schema(
95        model: impl Into<String>,
96        prompt: impl Into<String>,
97        schema: DirectJsonSchema,
98    ) -> Self {
99        Self {
100            output: DirectOutputSpec::JsonSchema(schema),
101            ..Self::text(model, prompt)
102        }
103    }
104
105    pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
106        self.replay = Some(crate::RuntimeReplay { key: key.into() });
107        self
108    }
109
110    pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
111        self.caused_by = Some(caused_by);
112        self
113    }
114}
115
116#[derive(Debug, thiserror::Error, Clone)]
117pub enum DirectLlmError {
118    #[error("invalid request: {0}")]
119    InvalidRequest(String),
120    #[error("transport error: {0}")]
121    Transport(#[from] LlmTransportError),
122}
123
124pub struct DirectLlmClient {
125    provider: ProviderHandle,
126    trace_sink: Option<Arc<dyn TraceSink>>,
127    trace_context: TraceContext,
128    clock: Arc<dyn crate::Clock>,
129}
130
131impl DirectLlmClient {
132    pub fn new(provider: ProviderHandle) -> Self {
133        Self {
134            provider,
135            trace_sink: None,
136            trace_context: TraceContext::default(),
137            clock: Arc::new(crate::SystemClock),
138        }
139    }
140
141    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
142        self.trace_sink = sink;
143        self
144    }
145
146    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
147        self.trace_context = context;
148        self
149    }
150
151    pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
152        self.clock = clock;
153        self
154    }
155
156    pub fn provider(&self) -> &ProviderHandle {
157        &self.provider
158    }
159
160    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
161        &mut self.provider
162    }
163
164    pub async fn complete(
165        &mut self,
166        request: DirectRequest,
167    ) -> Result<LlmResponse, DirectLlmError> {
168        if let Some(variant) = request.model_variant.as_deref() {
169            self.provider
170                .validate_variant(&request.model, variant)
171                .map_err(DirectLlmError::InvalidRequest)?;
172        }
173
174        let model = request.model.clone();
175        let llm_request = build_llm_request(&self.provider, request, model);
176        let llm_call_id = if self.trace_sink.is_some() {
177            let id = uuid::Uuid::new_v4().to_string();
178            crate::trace::emit_trace(
179                &self.trace_sink,
180                &self.trace_context,
181                TraceContext::default().for_llm_call(id.clone()),
182                TraceEvent::LlmCallStarted {
183                    request: crate::trace::trace_llm_request(&llm_request),
184                },
185                self.clock.as_ref(),
186            );
187            Some(id)
188        } else {
189            None
190        };
191        match self.provider.complete(llm_request).await {
192            Ok(response) => {
193                if let Some(llm_call_id) = llm_call_id {
194                    crate::trace::emit_trace(
195                        &self.trace_sink,
196                        &self.trace_context,
197                        TraceContext::default().for_llm_call(llm_call_id),
198                        TraceEvent::LlmCallCompleted {
199                            response: crate::trace::trace_llm_response(
200                                response.full_text.clone(),
201                                0,
202                                Some(response.terminal_reason),
203                                crate::trace::trace_output_parts(&response.parts),
204                            ),
205                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
206                            provider_usage: response.provider_usage.clone(),
207                            stream_summary: None,
208                        },
209                        self.clock.as_ref(),
210                    );
211                }
212                Ok(response)
213            }
214            Err(error) => {
215                if let Some(llm_call_id) = llm_call_id {
216                    crate::trace::emit_trace(
217                        &self.trace_sink,
218                        &self.trace_context,
219                        TraceContext::default().for_llm_call(llm_call_id),
220                        TraceEvent::LlmCallFailed {
221                            error: TraceError {
222                                message: error.message.clone(),
223                                retryable: error.retryable,
224                                terminal_reason: Some(error.terminal_reason.code().to_string()),
225                                code: error.code.clone(),
226                                raw: error.raw.clone(),
227                            },
228                            stream_summary: None,
229                        },
230                        self.clock.as_ref(),
231                    );
232                }
233                Err(DirectLlmError::from(error))
234            }
235        }
236    }
237}
238
239pub(crate) fn build_llm_request(
240    provider: &ProviderHandle,
241    request: DirectRequest,
242    model: String,
243) -> LlmRequest {
244    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
245    let DirectRequest {
246        model: _,
247        model_variant,
248        messages,
249        attachments,
250        output,
251        generation,
252        stream_events: _,
253        session_id,
254        caused_by: _,
255        replay: _,
256    } = request;
257
258    let output_spec = match output {
259        DirectOutputSpec::Text => None,
260        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
261        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
262            name: schema.name,
263            schema: schema.schema,
264            strict: schema.strict,
265        })),
266    };
267
268    let mut llm_messages = Vec::new();
269    for message in messages {
270        let role = match message.role {
271            DirectRole::System => LlmRole::System,
272            DirectRole::User => LlmRole::User,
273            DirectRole::Assistant => LlmRole::Assistant,
274        };
275        let mut blocks: Vec<LlmContentBlock> = Vec::new();
276        for part in message.parts {
277            match part {
278                DirectPart::Text(text) => {
279                    if !text.is_empty() {
280                        blocks.push(LlmContentBlock::Text {
281                            text: text.into(),
282                            response_meta: None,
283                            cache_breakpoint: false,
284                        });
285                    }
286                }
287                DirectPart::Image(idx) => {
288                    blocks.push(LlmContentBlock::Image {
289                        attachment_idx: idx,
290                    });
291                }
292            }
293        }
294        if !blocks.is_empty() {
295            llm_messages.push(LlmMessage::new(role, blocks));
296        }
297    }
298
299    LlmRequest {
300        model,
301        messages: llm_messages,
302        attachments,
303        tools: Vec::new().into(),
304        tool_choice: LlmToolChoice::None,
305        model_variant,
306        generation,
307        session_id,
308        output_spec,
309        stream_events,
310        provider_trace: None,
311    }
312}
313
314fn transport_stream_events_for_direct(
315    provider: &ProviderHandle,
316    requested: Option<LlmEventSender>,
317) -> Option<LlmEventSender> {
318    if requested.is_some() {
319        return requested;
320    }
321    if provider.requires_streaming() {
322        Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
323    } else {
324        None
325    }
326}