Skip to main content

lash_core/
direct.rs

1use crate::llm::transport::LlmTransportError;
2use crate::llm::types::{
3    LlmAttachment, LlmContentBlock, LlmEventSender, LlmJsonSchema, LlmMessage, LlmOutputSpec,
4    LlmRequest, LlmResponse, LlmRole, LlmStreamEvent, LlmToolChoice,
5};
6use crate::provider::ProviderHandle;
7use lash_trace::{TraceContext, TraceError, TraceEvent, TraceSink};
8use std::sync::Arc;
9
/// Speaker of a [`DirectMessage`].
///
/// Each variant is lowered one-to-one onto the matching `LlmRole` variant
/// when a direct request is converted for the provider.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DirectRole {
    /// Instructions that frame the conversation.
    System,
    /// Input authored by the caller.
    User,
    /// A prior model turn supplied as context.
    Assistant,
}
16
/// One piece of content inside a [`DirectMessage`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum DirectPart {
    /// Plain text. Empty strings are skipped when the request is lowered.
    Text(String),
    /// Reference to an attachment by position. The index is forwarded
    /// unchanged as `attachment_idx`; presumably it indexes the request's
    /// `attachments` list — NOTE(review): it is not bounds-checked here.
    Image(usize),
}
22
23#[derive(Clone, Debug, PartialEq, Eq)]
24pub struct DirectMessage {
25    pub role: DirectRole,
26    pub parts: Vec<DirectPart>,
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
30pub struct DirectJsonSchema {
31    pub name: String,
32    pub schema: serde_json::Value,
33    pub strict: bool,
34}
35
36#[derive(Clone, Debug, PartialEq, Eq, Default)]
37pub enum DirectOutputSpec {
38    #[default]
39    Text,
40    JsonObject,
41    JsonSchema(DirectJsonSchema),
42}
43
44#[derive(Clone, Debug)]
45pub struct DirectRequest {
46    pub model: String,
47    pub model_variant: Option<String>,
48    pub messages: Vec<DirectMessage>,
49    pub attachments: Vec<LlmAttachment>,
50    pub output: DirectOutputSpec,
51    pub stream_events: Option<LlmEventSender>,
52    pub session_id: Option<String>,
53    /// Set this when issuing a direct completion from inside a tool's
54    /// `execute`. Carries the calling tool's call id all the way to the
55    /// trace event so the renderer can group fan-outs (e.g.
56    /// tournament_rerank's batch reranks) under their parent.
57    pub originating_tool_call_id: Option<String>,
58}
59
60impl DirectRequest {
61    pub fn text(model: impl Into<String>, prompt: impl Into<String>) -> Self {
62        Self {
63            model: model.into(),
64            model_variant: None,
65            messages: vec![DirectMessage {
66                role: DirectRole::User,
67                parts: vec![DirectPart::Text(prompt.into())],
68            }],
69            attachments: Vec::new(),
70            output: DirectOutputSpec::Text,
71            stream_events: None,
72            session_id: None,
73            originating_tool_call_id: None,
74        }
75    }
76
77    pub fn json(model: impl Into<String>, prompt: impl Into<String>) -> Self {
78        Self {
79            output: DirectOutputSpec::JsonObject,
80            ..Self::text(model, prompt)
81        }
82    }
83
84    pub fn json_schema(
85        model: impl Into<String>,
86        prompt: impl Into<String>,
87        schema: DirectJsonSchema,
88    ) -> Self {
89        Self {
90            output: DirectOutputSpec::JsonSchema(schema),
91            ..Self::text(model, prompt)
92        }
93    }
94}
95
96#[derive(Debug, thiserror::Error, Clone)]
97pub enum DirectLlmError {
98    #[error("invalid request: {0}")]
99    InvalidRequest(String),
100    #[error("transport error: {0}")]
101    Transport(#[from] LlmTransportError),
102}
103
104pub struct DirectLlmClient {
105    provider: ProviderHandle,
106    trace_sink: Option<Arc<dyn TraceSink>>,
107    trace_context: TraceContext,
108}
109
110impl DirectLlmClient {
111    pub fn new(provider: ProviderHandle) -> Self {
112        Self {
113            provider,
114            trace_sink: None,
115            trace_context: TraceContext::default(),
116        }
117    }
118
119    pub fn with_trace_sink(mut self, sink: Option<Arc<dyn TraceSink>>) -> Self {
120        self.trace_sink = sink;
121        self
122    }
123
124    pub fn with_trace_context(mut self, context: TraceContext) -> Self {
125        self.trace_context = context;
126        self
127    }
128
129    pub fn provider(&self) -> &ProviderHandle {
130        &self.provider
131    }
132
133    pub fn provider_mut(&mut self) -> &mut ProviderHandle {
134        &mut self.provider
135    }
136
137    pub async fn complete(
138        &mut self,
139        request: DirectRequest,
140    ) -> Result<LlmResponse, DirectLlmError> {
141        let normalized_model = self.provider.resolve_model(&request.model);
142        if let Some(variant) = request.model_variant.as_deref() {
143            self.provider
144                .validate_variant(&normalized_model, variant)
145                .map_err(DirectLlmError::InvalidRequest)?;
146        }
147
148        let llm_request = build_llm_request(&self.provider, request, normalized_model);
149        let llm_call_id = if self.trace_sink.is_some() {
150            let id = uuid::Uuid::new_v4().to_string();
151            crate::trace::emit_trace(
152                &self.trace_sink,
153                &self.trace_context,
154                TraceContext::default().for_llm_call(id.clone()),
155                TraceEvent::LlmCallStarted {
156                    request: crate::trace::trace_llm_request(&llm_request),
157                },
158            );
159            Some(id)
160        } else {
161            None
162        };
163        match self.provider.complete(llm_request).await {
164            Ok(response) => {
165                if let Some(llm_call_id) = llm_call_id {
166                    crate::trace::emit_trace(
167                        &self.trace_sink,
168                        &self.trace_context,
169                        TraceContext::default().for_llm_call(llm_call_id),
170                        TraceEvent::LlmCallCompleted {
171                            response: crate::trace::trace_llm_response(
172                                response.full_text.clone(),
173                                0,
174                                Some(response.terminal_reason),
175                                crate::trace::trace_output_parts(&response.parts),
176                            ),
177                            usage: Some(crate::trace::trace_usage_from_llm(&response.usage)),
178                            provider_usage: response.provider_usage.clone(),
179                            stream_summary: None,
180                        },
181                    );
182                }
183                Ok(response)
184            }
185            Err(error) => {
186                if let Some(llm_call_id) = llm_call_id {
187                    crate::trace::emit_trace(
188                        &self.trace_sink,
189                        &self.trace_context,
190                        TraceContext::default().for_llm_call(llm_call_id),
191                        TraceEvent::LlmCallFailed {
192                            error: TraceError {
193                                message: error.message.clone(),
194                                retryable: error.retryable,
195                                terminal_reason: Some(error.terminal_reason.code().to_string()),
196                                code: error.code.clone(),
197                                raw: error.raw.clone(),
198                            },
199                            stream_summary: None,
200                        },
201                    );
202                }
203                Err(DirectLlmError::from(error))
204            }
205        }
206    }
207}
208
209pub(crate) fn build_llm_request(
210    provider: &ProviderHandle,
211    request: DirectRequest,
212    model: String,
213) -> LlmRequest {
214    let stream_events = transport_stream_events_for_direct(provider, request.stream_events);
215    let DirectRequest {
216        model: _,
217        model_variant,
218        messages,
219        attachments,
220        output,
221        stream_events: _,
222        session_id,
223        originating_tool_call_id: _,
224    } = request;
225
226    let output_spec = match output {
227        DirectOutputSpec::Text => None,
228        DirectOutputSpec::JsonObject => Some(LlmOutputSpec::JsonObject),
229        DirectOutputSpec::JsonSchema(schema) => Some(LlmOutputSpec::JsonSchema(LlmJsonSchema {
230            name: schema.name,
231            schema: schema.schema,
232            strict: schema.strict,
233        })),
234    };
235
236    let mut llm_messages = Vec::new();
237    for message in messages {
238        let role = match message.role {
239            DirectRole::System => LlmRole::System,
240            DirectRole::User => LlmRole::User,
241            DirectRole::Assistant => LlmRole::Assistant,
242        };
243        let mut blocks: Vec<LlmContentBlock> = Vec::new();
244        for part in message.parts {
245            match part {
246                DirectPart::Text(text) => {
247                    if !text.is_empty() {
248                        blocks.push(LlmContentBlock::Text {
249                            text: text.into(),
250                            response_meta: None,
251                            cache_breakpoint: false,
252                        });
253                    }
254                }
255                DirectPart::Image(idx) => {
256                    blocks.push(LlmContentBlock::Image {
257                        attachment_idx: idx,
258                    });
259                }
260            }
261        }
262        if !blocks.is_empty() {
263            llm_messages.push(LlmMessage::new(role, blocks));
264        }
265    }
266
267    LlmRequest {
268        model,
269        messages: llm_messages,
270        attachments,
271        tools: Vec::new().into(),
272        tool_choice: LlmToolChoice::None,
273        model_variant,
274        session_id,
275        output_spec,
276        stream_events,
277        provider_trace: None,
278    }
279}
280
281fn transport_stream_events_for_direct(
282    provider: &ProviderHandle,
283    requested: Option<LlmEventSender>,
284) -> Option<LlmEventSender> {
285    if requested.is_some() {
286        return requested;
287    }
288    if provider.requires_streaming() {
289        Some(LlmEventSender::new(|_event: LlmStreamEvent| {}))
290    } else {
291        None
292    }
293}