Skip to main content

lash_core/
direct.rs

1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4    LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
10#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum DirectRole {
13    System,
14    User,
15    Assistant,
16}
17
18#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
19pub enum DirectPart {
20    Text(String),
21    Image(usize),
22}
23
24#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub struct DirectMessage {
26    pub role: DirectRole,
27    pub parts: Vec<DirectPart>,
28}
29
30#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub struct DirectJsonSchema {
32    pub name: String,
33    pub schema: serde_json::Value,
34    pub strict: bool,
35}
36
37#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
38pub enum DirectOutputSpec {
39    #[default]
40    Text,
41    JsonObject,
42    JsonSchema(DirectJsonSchema),
43}
44
45#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
46pub struct DirectRequest {
47    pub model: String,
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub model_variant: Option<String>,
50    #[serde(default)]
51    pub messages: Vec<DirectMessage>,
52    #[serde(default)]
53    pub attachments: Vec<LlmAttachment>,
54    #[serde(default)]
55    pub output: DirectOutputSpec,
56    #[serde(default)]
57    pub generation: crate::GenerationOptions,
58    #[serde(default, skip)]
59    pub stream_events: Option<LlmEventSender>,
60    #[serde(default, skip_serializing_if = "Option::is_none")]
61    pub session_id: Option<String>,
62    #[serde(default, skip_serializing_if = "Option::is_none")]
63    pub caused_by: Option<crate::CausalRef>,
64    #[serde(default, skip_serializing_if = "Option::is_none")]
65    pub replay: Option<crate::RuntimeReplay>,
66}
67
68impl DirectRequest {
69    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
70        Self {
71            model: model.into(),
72            model_variant: None,
73            messages: vec![DirectMessage {
74                role: DirectRole::User,
75                parts: vec![DirectPart::Text(prompt.into())],
76            }],
77            attachments: Vec::new(),
78            output: DirectOutputSpec::Text,
79            generation: crate::GenerationOptions::default(),
80            stream_events: None,
81            session_id: None,
82            caused_by: None,
83            replay: None,
84        }
85    }
86
87    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
88        Self {
89            output: DirectOutputSpec::JsonObject,
90            ..Self::text(model, prompt)
91        }
92    }
93
94    pub fn json_schema(
95        model: impl Into<String>,
96        prompt: impl Into<String>,
97        schema: DirectJsonSchema,
98    ) -> Self {
99        Self {
100            output: DirectOutputSpec::JsonSchema(schema),
101            ..Self::text(model, prompt)
102        }
103    }
104
105    pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
106        self.replay = Some(crate::RuntimeReplay { key: key.into() });
107        self
108    }
109
110    pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
111        self.caused_by = Some(caused_by);
112        self
113    }
114}
115
116#[derive(Debug, thiserror::Error, Clone)]
117pub enum DirectLlmError {
118    #[error("invalid request: {0}")]
119    InvalidRequest(String),
120    #[error("transport error: {0}")]
121    Transport(#[from] LlmTransportError),
122}
123
/// Thin client that sends [`DirectRequest`]s to a single provider and, when a
/// trace sink is configured, emits LLM call trace events around each call.
pub struct DirectLlmClient {
    /// Provider that validates model variants and executes completions.
    provider: ProviderHandle,
    /// Destination for trace events; `None` disables trace emission in `complete`.
    trace_sink: Option<Arc<dyn TraceSink>>,
    /// Context passed alongside every trace event emitted by this client.
    trace_context: TraceContext,
}
129
130impl DirectLlmClient {
131    pub fn new(provider: ProviderHandle) -> Self {
132        Self {
133            provider,
134            trace_sink: None,
135            trace_context: TraceContext::default(),
136        }
137    }
138
139    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
140        self.trace_sink = sink;
141        self
142    }
143
144    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
145        self.trace_context = context;
146        self
147    }
148
149    pub fn provider(&self) -> &ProviderHandle {
150        &self.provider
151    }
152
153    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
154        &mut self.provider
155    }
156
157    pub async fn complete(
158        &mut self,
159        request: DirectRequest,
160    ) -> Result<LlmResponse, DirectLlmError> {
161        if let Some(variant) = request.model_variant.as_deref() {
162            self.provider
163                .validate_variant(&request.model, variant)
164                .map_err(DirectLlmError::InvalidRequest)?;
165        }
166
167        let model = request.model.clone();
168        let llm_request = build_llm_request(&self.provider, request, model);
169        let llm_call_id = if self.trace_sink.is_some() {
170            let id = uuid::Uuid::new_v4().to_string();
171            crate::trace::emit_trace(
172                &self.trace_sink,
173                &self.trace_context,
174                TraceContext::default().for_llm_call(id.clone()),
175                TraceEvent::LlmCallStarted {
176                    request: crate::trace::trace_llm_request(&llm_request),
177                },
178            );
179            Some(id)
180        } else {
181            None
182        };
183        match self.provider.complete(llm_request).await {
184            Ok(response) => {
185                if let Some(llm_call_id) = llm_call_id {
186                    crate::trace::emit_trace(
187                        &self.trace_sink,
188                        &self.trace_context,
189                        TraceContext::default().for_llm_call(llm_call_id),
190                        TraceEvent::LlmCallCompleted {
191                            response: crate::trace::trace_llm_response(
192                                response.full_text.clone(),
193                                0,
194                                Some(response.terminal_reason),
195                                crate::trace::trace_output_parts(&response.parts),
196                            ),
197                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
198                            provider_usage: response.provider_usage.clone(),
199                            stream_summary: None,
200                        },
201                    );
202                }
203                Ok(response)
204            }
205            Err(error) => {
206                if let Some(llm_call_id) = llm_call_id {
207                    crate::trace::emit_trace(
208                        &self.trace_sink,
209                        &self.trace_context,
210                        TraceContext::default().for_llm_call(llm_call_id),
211                        TraceEvent::LlmCallFailed {
212                            error: TraceError {
213                                message: error.message.clone(),
214                                retryable: error.retryable,
215                                terminal_reason: Some(error.terminal_reason.code().to_string()),
216                                code: error.code.clone(),
217                                raw: error.raw.clone(),
218                            },
219                            stream_summary: None,
220                        },
221                    );
222                }
223                Err(DirectLlmError::from(error))
224            }
225        }
226    }
227}
228
229pub(crate) fn build_llm_request(
230    provider: &ProviderHandle,
231    request: DirectRequest,
232    model: String,
233) -> LlmRequest {
234    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
235    let DirectRequest {
236        model: _,
237        model_variant,
238        messages,
239        attachments,
240        output,
241        generation,
242        stream_events: _,
243        session_id,
244        caused_by: _,
245        replay: _,
246    } = request;
247
248    let output_spec = match output {
249        DirectOutputSpec::Text => None,
250        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
251        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
252            name: schema.name,
253            schema: schema.schema,
254            strict: schema.strict,
255        })),
256    };
257
258    let mut llm_messages = Vec::new();
259    for message in messages {
260        let role = match message.role {
261            DirectRole::System => LlmRole::System,
262            DirectRole::User => LlmRole::User,
263            DirectRole::Assistant => LlmRole::Assistant,
264        };
265        let mut blocks: Vec<LlmContentBlock> = Vec::new();
266        for part in message.parts {
267            match part {
268                DirectPart::Text(text) => {
269                    if !text.is_empty() {
270                        blocks.push(LlmContentBlock::Text {
271                            text: text.into(),
272                            response_meta: None,
273                            cache_breakpoint: false,
274                        });
275                    }
276                }
277                DirectPart::Image(idx) => {
278                    blocks.push(LlmContentBlock::Image {
279                        attachment_idx: idx,
280                    });
281                }
282            }
283        }
284        if !blocks.is_empty() {
285            llm_messages.push(LlmMessage::new(role, blocks));
286        }
287    }
288
289    LlmRequest {
290        model,
291        messages: llm_messages,
292        attachments,
293        tools: Vec::new().into(),
294        tool_choice: LlmToolChoice::None,
295        model_variant,
296        generation,
297        session_id,
298        output_spec,
299        stream_events,
300        provider_trace: None,
301    }
302}
303
304fn transport_stream_events_for_direct(
305    provider: &ProviderHandle,
306    requested: Option<LlmEventSender>,
307) -> Option<LlmEventSender> {
308    if requested.is_some() {
309        return requested;
310    }
311    if provider.requires_streaming() {
312        Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
313    } else {
314        None
315    }
316}