1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3 LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4 LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
/// Author role of a [`DirectMessage`].
///
/// Serialized in `snake_case` (`"system"`, `"user"`, `"assistant"`). When a
/// request is lowered by `build_llm_request`, each variant maps one-to-one onto
/// the [`LlmRole`] variant of the same name.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DirectRole {
    System,
    User,
    Assistant,
}
17
/// One piece of content inside a [`DirectMessage`].
///
/// Uses serde's default (externally tagged) enum representation, e.g.
/// `{"Text": "..."}` / `{"Image": 0}`. Unlike [`DirectRole`] there is no
/// `rename_all`, so the variant names are serialized as written.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum DirectPart {
    /// Plain text. Empty strings are dropped when the request is lowered by
    /// `build_llm_request`.
    Text(String),
    /// Image reference, forwarded as `LlmContentBlock::Image { attachment_idx }`.
    /// NOTE(review): presumably an index into `DirectRequest::attachments`; no
    /// bounds check is performed in this file — confirm the provider validates it.
    Image(usize),
}
23
/// A single conversation turn of a [`DirectRequest`].
///
/// When lowered by `build_llm_request`, a message whose parts all get dropped
/// (for example, only empty text) is omitted from the outgoing request entirely.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectMessage {
    /// Who authored this turn.
    pub role: DirectRole,
    /// Content parts, kept in order.
    pub parts: Vec<DirectPart>,
}
29
/// A named JSON schema used to constrain structured output.
///
/// Copied field-for-field into [`LlmJsonSchema`] by `build_llm_request`.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DirectJsonSchema {
    /// Schema name, forwarded as `LlmJsonSchema::name`.
    pub name: String,
    /// The JSON schema document itself.
    pub schema: serde_json::Value,
    /// Forwarded as `LlmJsonSchema::strict`; how it is enforced depends on the
    /// provider (not visible here).
    pub strict: bool,
}
36
/// Requested response format for a [`DirectRequest`].
#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum DirectOutputSpec {
    /// Free-form text; lowered to no output spec at all (`output_spec: None`).
    #[default]
    Text,
    /// Any JSON object; lowered to `LlmOutputSpec::JsonObject`.
    JsonObject,
    /// JSON conforming to the given schema; lowered to `LlmOutputSpec::JsonSchema`.
    JsonSchema(DirectJsonSchema),
}
44
/// A single, tool-less LLM call as issued through [`DirectLlmClient::complete`].
///
/// Field order is the serialization order; optional fields are omitted from
/// serialized output when `None`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DirectRequest {
    /// Model identifier sent to the provider.
    pub model: String,
    /// Optional model variant. When present, `DirectLlmClient::complete` checks it
    /// with `ProviderHandle::validate_variant` before sending.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub model_variant: Option<String>,
    /// Conversation turns, in order. Defaults to empty when absent on deserialize.
    #[serde(default)]
    pub messages: Vec<DirectMessage>,
    /// Attachments forwarded unchanged into the `LlmRequest`.
    #[serde(default)]
    pub attachments: Vec<LlmAttachment>,
    /// Requested response format (plain text by default).
    #[serde(default)]
    pub output: DirectOutputSpec,
    /// Generation options forwarded unchanged into the `LlmRequest`.
    #[serde(default)]
    pub generation: crate::GenerationOptions,
    /// Optional sink for stream events. Never serialized, and always `None` after
    /// deserialization. When `None`, a no-op sender is substituted if the
    /// provider requires streaming (see `transport_stream_events_for_direct`).
    #[serde(default, skip)]
    pub stream_events: Option<LlmEventSender>,
    /// Session id forwarded unchanged into the `LlmRequest`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,
    /// Causal reference; not forwarded into the `LlmRequest` by `build_llm_request`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub caused_by: Option<crate::CausalRef>,
    /// Replay metadata (see `with_replay_key`); not forwarded into the
    /// `LlmRequest` by `build_llm_request`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub replay: Option<crate::RuntimeReplay>,
}
67
68impl DirectRequest {
69 pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
70 Self {
71 model: model.into(),
72 model_variant: None,
73 messages: vec![DirectMessage {
74 role: DirectRole::User,
75 parts: vec![DirectPart::Text(prompt.into())],
76 }],
77 attachments: Vec::new(),
78 output: DirectOutputSpec::Text,
79 generation: crate::GenerationOptions::default(),
80 stream_events: None,
81 session_id: None,
82 caused_by: None,
83 replay: None,
84 }
85 }
86
87 pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
88 Self {
89 output: DirectOutputSpec::JsonObject,
90 ..Self::text(model, prompt)
91 }
92 }
93
94 pub fn json_schema(
95 model: impl Into<String>,
96 prompt: impl Into<String>,
97 schema: DirectJsonSchema,
98 ) -> Self {
99 Self {
100 output: DirectOutputSpec::JsonSchema(schema),
101 ..Self::text(model, prompt)
102 }
103 }
104
105 pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
106 self.replay = Some(crate::RuntimeReplay { key: key.into() });
107 self
108 }
109
110 pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
111 self.caused_by = Some(caused_by);
112 self
113 }
114}
115
/// Failure modes of [`DirectLlmClient::complete`].
#[derive(Debug, thiserror::Error, Clone)]
pub enum DirectLlmError {
    /// The request was rejected before being sent, e.g. because
    /// `ProviderHandle::validate_variant` refused the model variant. Carries that
    /// validation message.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// The provider call itself failed; wraps the underlying transport error.
    #[error("transport error: {0}")]
    Transport(#[from] LlmTransportError),
}
123
/// Thin client that sends [`DirectRequest`]s to a provider and optionally
/// records `LlmCallStarted` / `LlmCallCompleted` / `LlmCallFailed` trace events.
pub struct DirectLlmClient {
    /// Provider used to validate variants and execute completions.
    provider: ProviderHandle,
    /// When `None`, no trace events are emitted and no call id is generated.
    trace_sink: Option<Arc<dyn TraceSink>>,
    /// Base context passed to every trace emission.
    trace_context: TraceContext,
    /// Clock handed to trace emission (`SystemClock` unless overridden).
    clock: Arc<dyn crate::Clock>,
}
130
131impl DirectLlmClient {
132 pub fn new(provider: ProviderHandle) -> Self {
133 Self {
134 provider,
135 trace_sink: None,
136 trace_context: TraceContext::default(),
137 clock: Arc::new(crate::SystemClock),
138 }
139 }
140
141 pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
142 self.trace_sink = sink;
143 self
144 }
145
146 pub fn with_trace_context(mut self, context: TraceContext) -> Self {
147 self.trace_context = context;
148 self
149 }
150
151 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
152 self.clock = clock;
153 self
154 }
155
156 pub fn provider(&self) -> &ProviderHandle {
157 &self.provider
158 }
159
160 pub fn provider_mut(&mut self) -> &mut ProviderHandle {
161 &mut self.provider
162 }
163
164 pub async fn complete(
165 &mut self,
166 request: DirectRequest,
167 ) -> Result<LlmResponse, DirectLlmError> {
168 if let Some(variant) = request.model_variant.as_deref() {
169 self.provider
170 .validate_variant(&request.model, variant)
171 .map_err(DirectLlmError::InvalidRequest)?;
172 }
173
174 let model = request.model.clone();
175 let llm_request = build_llm_request(&self.provider, request, model);
176 let llm_call_id = if self.trace_sink.is_some() {
177 let id = uuid::Uuid::new_v4().to_string();
178 crate::trace::emit_trace(
179 &self.trace_sink,
180 &self.trace_context,
181 TraceContext::default().for_llm_call(id.clone()),
182 TraceEvent::LlmCallStarted {
183 request: crate::trace::trace_llm_request(&llm_request),
184 },
185 self.clock.as_ref(),
186 );
187 Some(id)
188 } else {
189 None
190 };
191 match self.provider.complete(llm_request).await {
192 Ok(response) => {
193 if let Some(llm_call_id) = llm_call_id {
194 crate::trace::emit_trace(
195 &self.trace_sink,
196 &self.trace_context,
197 TraceContext::default().for_llm_call(llm_call_id),
198 TraceEvent::LlmCallCompleted {
199 response: crate::trace::trace_llm_response(
200 response.full_text.clone(),
201 0,
202 Some(response.terminal_reason),
203 crate::trace::trace_output_parts(&response.parts),
204 ),
205 usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
206 provider_usage: response.provider_usage.clone(),
207 stream_summary: None,
208 },
209 self.clock.as_ref(),
210 );
211 }
212 Ok(response)
213 }
214 Err(error) => {
215 if let Some(llm_call_id) = llm_call_id {
216 crate::trace::emit_trace(
217 &self.trace_sink,
218 &self.trace_context,
219 TraceContext::default().for_llm_call(llm_call_id),
220 TraceEvent::LlmCallFailed {
221 error: TraceError {
222 message: error.message.clone(),
223 retryable: error.retryable,
224 terminal_reason: Some(error.terminal_reason.code().to_string()),
225 code: error.code.clone(),
226 raw: error.raw.clone(),
227 },
228 stream_summary: None,
229 },
230 self.clock.as_ref(),
231 );
232 }
233 Err(DirectLlmError::from(error))
234 }
235 }
236 }
237}
238
239pub(crate) fn build_llm_request(
240 provider: &ProviderHandle,
241 request: DirectRequest,
242 model: String,
243) -> LlmRequest {
244 let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
245 let DirectRequest {
246 model: _,
247 model_variant,
248 messages,
249 attachments,
250 output,
251 generation,
252 stream_events: _,
253 session_id,
254 caused_by: _,
255 replay: _,
256 } = request;
257
258 let output_spec = match output {
259 DirectOutputSpec::Text => None,
260 DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
261 DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
262 name: schema.name,
263 schema: schema.schema,
264 strict: schema.strict,
265 })),
266 };
267
268 let mut llm_messages = Vec::new();
269 for message in messages {
270 let role = match message.role {
271 DirectRole::System => LlmRole::System,
272 DirectRole::User => LlmRole::User,
273 DirectRole::Assistant => LlmRole::Assistant,
274 };
275 let mut blocks: Vec<LlmContentBlock> = Vec::new();
276 for part in message.parts {
277 match part {
278 DirectPart::Text(text) => {
279 if !text.is_empty() {
280 blocks.push(LlmContentBlock::Text {
281 text: text.into(),
282 response_meta: None,
283 cache_breakpoint: false,
284 });
285 }
286 }
287 DirectPart::Image(idx) => {
288 blocks.push(LlmContentBlock::Image {
289 attachment_idx: idx,
290 });
291 }
292 }
293 }
294 if !blocks.is_empty() {
295 llm_messages.push(LlmMessage::new(role, blocks));
296 }
297 }
298
299 LlmRequest {
300 model,
301 messages: llm_messages,
302 attachments,
303 tools: Vec::new().into(),
304 tool_choice: LlmToolChoice::None,
305 model_variant,
306 generation,
307 session_id,
308 output_spec,
309 stream_events,
310 provider_trace: None,
311 }
312}
313
314fn transport_stream_events_for_direct(
315 provider: &ProviderHandle,
316 requested: Option<LlmEventSender>,
317) -> Option<LlmEventSender> {
318 if requested.is_some() {
319 return requested;
320 }
321 if provider.requires_streaming() {
322 Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
323 } else {
324 None
325 }
326}