1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3 LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4 LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
10#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum DirectRole {
13 System,
14 User,
15 Assistant,
16}
17
18#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
19pub enum DirectPart {
20 Text(String),
21 Image(usize),
22}
23
24#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub struct DirectMessage {
26 pub role: DirectRole,
27 pub parts: Vec<DirectPart>,
28}
29
30#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
31pub struct DirectJsonSchema {
32 pub name: String,
33 pub schema: serde_json::Value,
34 pub strict: bool,
35}
36
37#[derive(Clone, Debug, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
38pub enum DirectOutputSpec {
39 #[default]
40 Text,
41 JsonObject,
42 JsonSchema(DirectJsonSchema),
43}
44
45#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
46pub struct DirectRequest {
47 pub model: String,
48 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub model_variant: Option<String>,
50 #[serde(default)]
51 pub messages: Vec<DirectMessage>,
52 #[serde(default)]
53 pub attachments: Vec<LlmAttachment>,
54 #[serde(default)]
55 pub output: DirectOutputSpec,
56 #[serde(default)]
57 pub generation: crate::GenerationOptions,
58 #[serde(default, skip)]
59 pub stream_events: Option<LlmEventSender>,
60 #[serde(default, skip_serializing_if = "Option::is_none")]
61 pub session_id: Option<String>,
62 #[serde(default, skip_serializing_if = "Option::is_none")]
63 pub caused_by: Option<crate::CausalRef>,
64 #[serde(default, skip_serializing_if = "Option::is_none")]
65 pub replay: Option<crate::RuntimeReplay>,
66}
67
68impl DirectRequest {
69 pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
70 Self {
71 model: model.into(),
72 model_variant: None,
73 messages: vec![DirectMessage {
74 role: DirectRole::User,
75 parts: vec![DirectPart::Text(prompt.into())],
76 }],
77 attachments: Vec::new(),
78 output: DirectOutputSpec::Text,
79 generation: crate::GenerationOptions::default(),
80 stream_events: None,
81 session_id: None,
82 caused_by: None,
83 replay: None,
84 }
85 }
86
87 pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
88 Self {
89 output: DirectOutputSpec::JsonObject,
90 ..Self::text(model, prompt)
91 }
92 }
93
94 pub fn json_schema(
95 model: impl Into<String>,
96 prompt: impl Into<String>,
97 schema: DirectJsonSchema,
98 ) -> Self {
99 Self {
100 output: DirectOutputSpec::JsonSchema(schema),
101 ..Self::text(model, prompt)
102 }
103 }
104
105 pub fn with_replay_key(mut self, key: impl Into<String>) -> Self {
106 self.replay = Some(crate::RuntimeReplay { key: key.into() });
107 self
108 }
109
110 pub fn with_caused_by(mut self, caused_by: crate::CausalRef) -> Self {
111 self.caused_by = Some(caused_by);
112 self
113 }
114}
115
116#[derive(Debug, thiserror::Error, Clone)]
117pub enum DirectLlmError {
118 #[error("invalid request: {0}")]
119 InvalidRequest(String),
120 #[error("transport error: {0}")]
121 Transport(#[from] LlmTransportError),
122}
123
/// Client that sends [`DirectRequest`]s to a provider and optionally emits
/// trace events around each call.
pub struct DirectLlmClient {
    /// Provider that validates model variants and executes requests.
    provider: ProviderHandle,
    /// Destination for trace events; when `None`, `complete` emits nothing.
    trace_sink: Option<Arc<dyn TraceSink>>,
    /// Context passed along with every emitted trace event.
    trace_context: TraceContext,
}
129
130impl DirectLlmClient {
131 pub fn new(provider: ProviderHandle) -> Self {
132 Self {
133 provider,
134 trace_sink: None,
135 trace_context: TraceContext::default(),
136 }
137 }
138
139 pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
140 self.trace_sink = sink;
141 self
142 }
143
144 pub fn with_trace_context(mut self, context: TraceContext) -> Self {
145 self.trace_context = context;
146 self
147 }
148
149 pub fn provider(&self) -> &ProviderHandle {
150 &self.provider
151 }
152
153 pub fn provider_mut(&mut self) -> &mut ProviderHandle {
154 &mut self.provider
155 }
156
157 pub async fn complete(
158 &mut self,
159 request: DirectRequest,
160 ) -> Result<LlmResponse, DirectLlmError> {
161 if let Some(variant) = request.model_variant.as_deref() {
162 self.provider
163 .validate_variant(&request.model, variant)
164 .map_err(DirectLlmError::InvalidRequest)?;
165 }
166
167 let model = request.model.clone();
168 let llm_request = build_llm_request(&self.provider, request, model);
169 let llm_call_id = if self.trace_sink.is_some() {
170 let id = uuid::Uuid::new_v4().to_string();
171 crate::trace::emit_trace(
172 &self.trace_sink,
173 &self.trace_context,
174 TraceContext::default().for_llm_call(id.clone()),
175 TraceEvent::LlmCallStarted {
176 request: crate::trace::trace_llm_request(&llm_request),
177 },
178 );
179 Some(id)
180 } else {
181 None
182 };
183 match self.provider.complete(llm_request).await {
184 Ok(response) => {
185 if let Some(llm_call_id) = llm_call_id {
186 crate::trace::emit_trace(
187 &self.trace_sink,
188 &self.trace_context,
189 TraceContext::default().for_llm_call(llm_call_id),
190 TraceEvent::LlmCallCompleted {
191 response: crate::trace::trace_llm_response(
192 response.full_text.clone(),
193 0,
194 Some(response.terminal_reason),
195 crate::trace::trace_output_parts(&response.parts),
196 ),
197 usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
198 provider_usage: response.provider_usage.clone(),
199 stream_summary: None,
200 },
201 );
202 }
203 Ok(response)
204 }
205 Err(error) => {
206 if let Some(llm_call_id) = llm_call_id {
207 crate::trace::emit_trace(
208 &self.trace_sink,
209 &self.trace_context,
210 TraceContext::default().for_llm_call(llm_call_id),
211 TraceEvent::LlmCallFailed {
212 error: TraceError {
213 message: error.message.clone(),
214 retryable: error.retryable,
215 terminal_reason: Some(error.terminal_reason.code().to_string()),
216 code: error.code.clone(),
217 raw: error.raw.clone(),
218 },
219 stream_summary: None,
220 },
221 );
222 }
223 Err(DirectLlmError::from(error))
224 }
225 }
226 }
227}
228
229pub(crate) fn build_llm_request(
230 provider: &ProviderHandle,
231 request: DirectRequest,
232 model: String,
233) -> LlmRequest {
234 let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
235 let DirectRequest {
236 model: _,
237 model_variant,
238 messages,
239 attachments,
240 output,
241 generation,
242 stream_events: _,
243 session_id,
244 caused_by: _,
245 replay: _,
246 } = request;
247
248 let output_spec = match output {
249 DirectOutputSpec::Text => None,
250 DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
251 DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
252 name: schema.name,
253 schema: schema.schema,
254 strict: schema.strict,
255 })),
256 };
257
258 let mut llm_messages = Vec::new();
259 for message in messages {
260 let role = match message.role {
261 DirectRole::System => LlmRole::System,
262 DirectRole::User => LlmRole::User,
263 DirectRole::Assistant => LlmRole::Assistant,
264 };
265 let mut blocks: Vec<LlmContentBlock> = Vec::new();
266 for part in message.parts {
267 match part {
268 DirectPart::Text(text) => {
269 if !text.is_empty() {
270 blocks.push(LlmContentBlock::Text {
271 text: text.into(),
272 response_meta: None,
273 cache_breakpoint: false,
274 });
275 }
276 }
277 DirectPart::Image(idx) => {
278 blocks.push(LlmContentBlock::Image {
279 attachment_idx: idx,
280 });
281 }
282 }
283 }
284 if !blocks.is_empty() {
285 llm_messages.push(LlmMessage::new(role, blocks));
286 }
287 }
288
289 LlmRequest {
290 model,
291 messages: llm_messages,
292 attachments,
293 tools: Vec::new().into(),
294 tool_choice: LlmToolChoice::None,
295 model_variant,
296 generation,
297 session_id,
298 output_spec,
299 stream_events,
300 provider_trace: None,
301 }
302}
303
304fn transport_stream_events_for_direct(
305 provider: &ProviderHandle,
306 requested: Option<LlmEventSender>,
307) -> Option<LlmEventSender> {
308 if requested.is_some() {
309 return requested;
310 }
311 if provider.requires_streaming() {
312 Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
313 } else {
314 None
315 }
316}