1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3 LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4 LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
/// Author of a [`DirectMessage`] in a direct (tool-free) LLM exchange.
///
/// Each variant maps one-to-one onto the matching `LlmRole` variant when the
/// request is lowered for the transport layer.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DirectRole {
    /// Instructions that frame the conversation.
    System,
    /// Input supplied by the caller.
    User,
    /// Output previously produced by the model.
    Assistant,
}
16
/// One piece of content inside a [`DirectMessage`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DirectPart {
    /// Plain text. Empty strings are skipped when the request is lowered.
    Text(String),
    /// Image reference, forwarded as the image block's `attachment_idx`.
    ///
    /// NOTE(review): presumably an index into `DirectRequest::attachments`;
    /// nothing in this file range-checks it — confirm against the transport.
    Image(usize),
}
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub struct DirectMessage {
25 pub role: DirectRole,
26 pub parts: Vec<DirectPart>,
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
30pub struct DirectJsonSchema {
31 pub name: String,
32 pub schema: serde_json::Value,
33 pub strict: bool,
34}
35
36#[derive(Clone, Debug, PartialEq, Eq, Default)]
37pub enum DirectOutputSpec {
38 #[default]
39 Text,
40 JsonObject,
41 JsonSchema(DirectJsonSchema),
42}
43
44#[derive(Clone, Debug)]
45pub struct DirectRequest {
46 pub model: String,
47 pub model_variant: Option<String>,
48 pub messages: Vec<DirectMessage>,
49 pub attachments: Vec<LlmAttachment>,
50 pub output: DirectOutputSpec,
51 pub stream_events: Option<LlmEventSender>,
52 pub session_id: Option<String>,
53 pub originating_tool_call_id: Option<String>,
58}
59
60impl DirectRequest {
61 pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
62 Self {
63 model: model.into(),
64 model_variant: None,
65 messages: vec![DirectMessage {
66 role: DirectRole::User,
67 parts: vec![DirectPart::Text(prompt.into())],
68 }],
69 attachments: Vec::new(),
70 output: DirectOutputSpec::Text,
71 stream_events: None,
72 session_id: None,
73 originating_tool_call_id: None,
74 }
75 }
76
77 pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
78 Self {
79 output: DirectOutputSpec::JsonObject,
80 ..Self::text(model, prompt)
81 }
82 }
83
84 pub fn json_schema(
85 model: impl Into<String>,
86 prompt: impl Into<String>,
87 schema: DirectJsonSchema,
88 ) -> Self {
89 Self {
90 output: DirectOutputSpec::JsonSchema(schema),
91 ..Self::text(model, prompt)
92 }
93 }
94}
95
96#[derive(Debug, thiserror::Error, Clone)]
97pub enum DirectLlmError {
98 #[error("invalid request: {0}")]
99 InvalidRequest(String),
100 #[error("transport error: {0}")]
101 Transport(#[from] LlmTransportError),
102}
103
104pub struct DirectLlmClient {
105 provider: ProviderHandle,
106 trace_sink: Option<Arc<dyn TraceSink>>,
107 trace_context: TraceContext,
108}
109
110impl DirectLlmClient {
111 pub fn new(provider: ProviderHandle) -> Self {
112 Self {
113 provider,
114 trace_sink: None,
115 trace_context: TraceContext::default(),
116 }
117 }
118
119 pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
120 self.trace_sink = sink;
121 self
122 }
123
124 pub fn with_trace_context(mut self, context: TraceContext) -> Self {
125 self.trace_context = context;
126 self
127 }
128
129 pub fn provider(&self) -> &ProviderHandle {
130 &self.provider
131 }
132
133 pub fn provider_mut(&mut self) -> &mut ProviderHandle {
134 &mut self.provider
135 }
136
137 pub async fn complete(
138 &mut self,
139 request: DirectRequest,
140 ) -> Result<LlmResponse, DirectLlmError> {
141 let normalized_model = self.provider.resolve_model(&request.model);
142 if let Some(variant) = request.model_variant.as_deref() {
143 self.provider
144 .validate_variant(&normalized_model, variant)
145 .map_err(DirectLlmError::InvalidRequest)?;
146 }
147
148 let llm_request = build_llm_request(&self.provider, request, normalized_model);
149 let llm_call_id = if self.trace_sink.is_some() {
150 let id = uuid::Uuid::new_v4().to_string();
151 crate::trace::emit_trace(
152 &self.trace_sink,
153 &self.trace_context,
154 TraceContext::default().for_llm_call(id.clone()),
155 TraceEvent::LlmCallStarted {
156 request: crate::trace::trace_llm_request(&llm_request),
157 },
158 );
159 Some(id)
160 } else {
161 None
162 };
163 match self.provider.complete(llm_request).await {
164 Ok(response) => {
165 if let Some(llm_call_id) = llm_call_id {
166 crate::trace::emit_trace(
167 &self.trace_sink,
168 &self.trace_context,
169 TraceContext::default().for_llm_call(llm_call_id),
170 TraceEvent::LlmCallCompleted {
171 response: crate::trace::trace_llm_response(
172 response.full_text.clone(),
173 0,
174 Some(response.terminal_reason),
175 crate::trace::trace_output_parts(&response.parts),
176 ),
177 usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
178 provider_usage: response.provider_usage.clone(),
179 stream_summary: None,
180 },
181 );
182 }
183 Ok(response)
184 }
185 Err(error) => {
186 if let Some(llm_call_id) = llm_call_id {
187 crate::trace::emit_trace(
188 &self.trace_sink,
189 &self.trace_context,
190 TraceContext::default().for_llm_call(llm_call_id),
191 TraceEvent::LlmCallFailed {
192 error: TraceError {
193 message: error.message.clone(),
194 retryable: error.retryable,
195 terminal_reason: Some(error.terminal_reason.code().to_string()),
196 code: error.code.clone(),
197 raw: error.raw.clone(),
198 },
199 stream_summary: None,
200 },
201 );
202 }
203 Err(DirectLlmError::from(error))
204 }
205 }
206 }
207}
208
209pub(crate) fn build_llm_request(
210 provider: &ProviderHandle,
211 request: DirectRequest,
212 model: String,
213) -> LlmRequest {
214 let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
215 let DirectRequest {
216 model: _,
217 model_variant,
218 messages,
219 attachments,
220 output,
221 stream_events: _,
222 session_id,
223 originating_tool_call_id: _,
224 } = request;
225
226 let output_spec = match output {
227 DirectOutputSpec::Text => None,
228 DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
229 DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
230 name: schema.name,
231 schema: schema.schema,
232 strict: schema.strict,
233 })),
234 };
235
236 let mut llm_messages = Vec::new();
237 for message in messages {
238 let role = match message.role {
239 DirectRole::System => LlmRole::System,
240 DirectRole::User => LlmRole::User,
241 DirectRole::Assistant => LlmRole::Assistant,
242 };
243 let mut blocks: Vec<LlmContentBlock> = Vec::new();
244 for part in message.parts {
245 match part {
246 DirectPart::Text(text) => {
247 if !text.is_empty() {
248 blocks.push(LlmContentBlock::Text {
249 text: text.into(),
250 response_meta: None,
251 cache_breakpoint: false,
252 });
253 }
254 }
255 DirectPart::Image(idx) => {
256 blocks.push(LlmContentBlock::Image {
257 attachment_idx: idx,
258 });
259 }
260 }
261 }
262 if !blocks.is_empty() {
263 llm_messages.push(LlmMessage::new(role, blocks));
264 }
265 }
266
267 LlmRequest {
268 model,
269 messages: llm_messages,
270 attachments,
271 tools: Vec::new().into(),
272 tool_choice: LlmToolChoice::None,
273 model_variant,
274 session_id,
275 output_spec,
276 stream_events,
277 provider_trace: None,
278 }
279}
280
281fn transport_stream_events_for_direct(
282 provider: &ProviderHandle,
283 requested: Option<LlmEventSender>,
284) -> Option<LlmEventSender> {
285 if requested.is_some() {
286 return requested;
287 }
288 if provider.requires_streaming() {
289 Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
290 } else {
291 None
292 }
293}