1use async_trait::async_trait;
2use reqwest::Client;
3#[cfg(feature = "schema")]
4use schemars::{JsonSchema, schema_for};
5use serde::{Deserialize, Serialize};
6use serde_json::{Value, json};
7use thiserror::Error;
8
9use crate::config::{ConfigError, ProviderConfig, ProviderKind, ProviderResolver};
10
/// A single message in a chat transcript sent to (or received from) a provider.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored this message (system / user / assistant / tool).
    pub role: MessageRole,
    /// Plain-text body of the message.
    pub content: String,
    /// For tool-result messages, the id of the tool call being answered;
    /// omitted from the serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
19
/// Role tag for a chat message; serialized lowercase ("system", "user",
/// "assistant", "tool") to match the OpenAI wire format.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
28
/// Asks the model to produce JSON conforming to the given JSON Schema.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StructuredOutput {
    /// A JSON Schema document describing the expected response shape.
    pub schema: serde_json::Value,
}
34
35impl StructuredOutput {
36 pub fn new(schema: serde_json::Value) -> Self {
37 Self { schema }
38 }
39
40 #[cfg(feature = "schema")]
41 pub fn from_type<T: JsonSchema>() -> Self {
42 let root = schema_for!(T);
43 let value =
44 serde_json::to_value(root.schema).expect("schemars schema should serialize to JSON");
45 Self { schema: value }
46 }
47}
48
/// Declaration of a callable tool exposed to the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name the model uses to reference this tool.
    pub name: String,
    /// Optional human-readable description; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON Schema describing the tool's argument object.
    pub json_schema: serde_json::Value,
}
57
/// Provider-agnostic request passed to [`LlmClient::invoke`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Conversation history, oldest first.
    pub messages: Vec<ChatMessage>,
    /// When set, the provider is asked for schema-constrained JSON output.
    pub structured_output: Option<StructuredOutput>,
    /// Tools the model may call; defaults to empty on deserialization.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Tool-choice mode forwarded to the provider (e.g. "auto");
    /// omitted from the serialized form when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}
68
/// Normalized model reply: accumulated text plus any requested tool calls.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Concatenated text output (may be empty when only tools were called).
    pub content: String,
    /// Tool invocations requested by the model, in order.
    pub tool_calls: Vec<ToolCall>,
}
75
/// A single tool invocation requested by the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id; echo it back via `ChatMessage::tool_call_id`.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Decoded argument object (parsed from the provider's JSON string).
    pub arguments: serde_json::Value,
}
82
/// Errors surfaced by [`LlmClient`] implementations.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Provider/credential resolution failed.
    #[error("configuration error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Transport failure or a non-success HTTP status from the provider.
    #[error("network error: {0}")]
    Network(String),
    /// Request/response JSON could not be encoded or decoded.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// Operation not available for the configured provider.
    #[error("unsupported operation: {0}")]
    Unsupported(String),
}
95
/// Abstraction over a chat-style LLM backend.
///
/// `Send + Sync` so one client can be shared across async tasks.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// Sends one request and resolves to the model's reply or an [`LlmError`].
    async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;
}
100
/// Concrete HTTP-backed [`LlmClient`] over a resolved [`ProviderConfig`].
///
/// NOTE(review): Rust naming convention would spell this `LlmClient`, but
/// that name is taken by the trait above and renaming would break callers.
pub struct LLMClient {
    // Shared reqwest client; reusing it pools connections across requests.
    http: Client,
    // Endpoint, model, and credential settings for the chosen provider.
    provider: ProviderConfig,
}
106
107impl LLMClient {
108 pub fn new(provider: ProviderConfig) -> Self {
109 Self {
110 http: Client::new(),
111 provider,
112 }
113 }
114
115 pub fn with_client(http: Client, provider: ProviderConfig) -> Self {
116 Self { http, provider }
117 }
118
119 pub fn from_env(kind: ProviderKind, model: &str) -> Result<Self, ConfigError> {
120 let resolver = ProviderResolver::from_process();
121 let provider = resolver.resolve_with_kind(kind, model)?;
122 Ok(Self::new(provider))
123 }
124
125 async fn invoke_openai(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
126 let base = self.provider.base_url.trim_end_matches('/');
127 let url = format!("{}/responses", base);
128 let payload = OpenAiPayload::from_request(request, &self.provider);
129 let mut builder = self.http.post(url).json(&payload);
130 if let Some(key) = &self.provider.api_key {
131 builder = builder.bearer_auth(key);
132 } else {
133 return Err(LlmError::Config(crate::config::ConfigError::MissingEnv(
134 ProviderKind::OpenAi,
135 )));
136 }
137 let resp = builder
138 .send()
139 .await
140 .map_err(|e| LlmError::Network(e.to_string()))?;
141 let status = resp.status();
142 let bytes = resp
143 .bytes()
144 .await
145 .map_err(|e| LlmError::Network(e.to_string()))?;
146 if !status.is_success() {
147 return Err(LlmError::Network(format!(
148 "OpenAI request failed with {}: {}",
149 status,
150 String::from_utf8_lossy(&bytes)
151 )));
152 }
153 let parsed: OpenAiResponse =
154 serde_json::from_slice(&bytes).map_err(|e| LlmError::Serialization(e.to_string()))?;
155 Ok(parsed.into_response()?)
156 }
157
158 async fn invoke_compat(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
159 let base = self.provider.base_url.trim_end_matches('/');
160 let url = format!("{}/chat/completions", base);
161 let payload = CompatPayload::from_request(request, &self.provider);
162 let mut builder = self.http.post(url).json(&payload);
163 if let Some(key) = &self.provider.api_key {
164 builder = builder.bearer_auth(key);
165 }
166 let resp = builder
167 .send()
168 .await
169 .map_err(|e| LlmError::Network(e.to_string()))?;
170 let status = resp.status();
171 let bytes = resp
172 .bytes()
173 .await
174 .map_err(|e| LlmError::Network(e.to_string()))?;
175 if !status.is_success() {
176 return Err(LlmError::Network(format!(
177 "Compat request failed with {}: {}",
178 status,
179 String::from_utf8_lossy(&bytes)
180 )));
181 }
182 let parsed: CompatResponse =
183 serde_json::from_slice(&bytes).map_err(|e| LlmError::Serialization(e.to_string()))?;
184 Ok(parsed.into_response()?)
185 }
186}
187
#[async_trait]
impl LlmClient for LLMClient {
    /// Dispatches to the provider-specific implementation based on the
    /// configured provider kind.
    async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
        match self.provider.kind {
            ProviderKind::OpenAi => self.invoke_openai(&request).await,
            ProviderKind::Ollama | ProviderKind::LmStudio => self.invoke_compat(&request).await,
        }
    }
}
197
/// Request body for the OpenAI Responses API (`/responses`).
#[derive(Serialize)]
struct OpenAiPayload<'a> {
    model: &'a str,
    input: Vec<ResponseInput<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
    // NOTE(review): `default` is a Deserialize-side attribute and is inert
    // here since this type only derives Serialize.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    tools: Vec<ResponseTool<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<Value>,
}
209
210impl<'a> OpenAiPayload<'a> {
211 fn from_request(req: &'a LlmRequest, provider: &'a ProviderConfig) -> Self {
212 let input = req
213 .messages
214 .iter()
215 .map(ResponseInput::from)
216 .collect::<Vec<_>>();
217 let response_format = req.structured_output.as_ref().map(|schema| ResponseFormat {
218 r#type: "json_schema",
219 json_schema: json!({
220 "name": "structured_output",
221 "schema": schema.schema.clone()
222 }),
223 });
224 let tools = req.tools.iter().map(ResponseTool::from).collect();
225 let tool_choice = req
226 .tool_choice
227 .as_ref()
228 .map(|choice| json!({ "type": choice }));
229 Self {
230 model: &provider.model,
231 input,
232 response_format,
233 tools,
234 tool_choice,
235 }
236 }
237}
238
/// One input item (message) in Responses API form: a role plus content parts.
#[derive(Serialize)]
struct ResponseInput<'a> {
    role: &'a MessageRole,
    content: Vec<ResponseContent<'a>>,
}
244
245impl<'a> From<&'a ChatMessage> for ResponseInput<'a> {
246 fn from(msg: &'a ChatMessage) -> Self {
247 let content = vec![ResponseContent {
248 r#type: "text",
249 text: &msg.content,
250 }];
251 Self {
252 role: &msg.role,
253 content,
254 }
255 }
256}
257
/// A typed content part of a Responses API input item (always "text" here).
#[derive(Serialize)]
struct ResponseContent<'a> {
    r#type: &'static str,
    text: &'a str,
}
263
/// Structured-output directive: `{"type": "json_schema", "json_schema": ...}`.
#[derive(Serialize)]
struct ResponseFormat {
    r#type: &'static str,
    json_schema: Value,
}
269
/// Wire form of a tool entry: `{"type": "function", "function": {...}}`.
#[derive(Serialize)]
struct ResponseTool<'a> {
    r#type: &'static str,
    function: ResponseFunction<'a>,
}
275
276impl<'a> From<&'a ToolDefinition> for ResponseTool<'a> {
277 fn from(tool: &'a ToolDefinition) -> Self {
278 Self {
279 r#type: "function",
280 function: ResponseFunction {
281 name: &tool.name,
282 description: tool.description.as_deref(),
283 parameters: &tool.json_schema,
284 },
285 }
286 }
287}
288
/// Function portion of a tool entry: name, optional description, and the
/// JSON Schema for its parameters.
#[derive(Serialize)]
struct ResponseFunction<'a> {
    name: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<&'a str>,
    parameters: &'a serde_json::Value,
}
296
/// Top level of a Responses API reply; only the `output` array is consumed.
#[derive(Deserialize)]
struct OpenAiResponse {
    output: Vec<ResponseOutput>,
}
301
302impl OpenAiResponse {
303 fn into_response(self) -> Result<LlmResponse, LlmError> {
304 let mut text = String::new();
305 let mut tool_calls = Vec::new();
306 for item in self.output {
307 for content in item.content {
308 match content {
309 OutputContent::OutputText { text: chunk } => {
310 text.push_str(&chunk.text);
311 }
312 OutputContent::ToolCalls { tool_calls: calls } => {
313 for call in calls {
314 let arguments = serde_json::from_str::<Value>(&call.function.arguments)
315 .map_err(|e| LlmError::Serialization(e.to_string()))?;
316 tool_calls.push(ToolCall {
317 id: call.id,
318 name: call.function.name,
319 arguments,
320 });
321 }
322 }
323 }
324 }
325 }
326 Ok(LlmResponse {
327 content: text,
328 tool_calls,
329 })
330 }
331}
332
/// One item of the Responses API `output` array; only content parts are kept.
#[derive(Deserialize)]
struct ResponseOutput {
    content: Vec<OutputContent>,
}
337
/// A content part of an output item, discriminated by its `"type"` field
/// (snake_case: "output_text" / "tool_calls").
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputContent {
    /// A chunk of generated text.
    OutputText { text: TextChunk },
    /// A batch of tool invocations requested by the model.
    ToolCalls { tool_calls: Vec<OpenAiToolCall> },
}
344
/// Wrapper around the nested `text` object of an "output_text" part.
#[derive(Deserialize)]
struct TextChunk {
    text: String,
}
349
/// Wire form of a tool call: call id plus the function name/arguments pair.
#[derive(Deserialize)]
struct OpenAiToolCall {
    id: String,
    function: OpenAiFunctionCall,
}
355
/// Function half of a tool call; `arguments` arrives as a JSON-encoded string.
#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}
361
/// Request body for OpenAI-compatible chat-completions endpoints
/// (`/chat/completions`), as served by Ollama and LM Studio.
#[derive(Serialize)]
struct CompatPayload<'a> {
    model: &'a str,
    messages: Vec<CompatMessage<'a>>,
    // NOTE(review): `default` is a Deserialize-side attribute and is inert
    // here since this type only derives Serialize.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    tools: Vec<ResponseTool<'a>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    response_format: Option<ResponseFormat>,
}
373
374impl<'a> CompatPayload<'a> {
375 fn from_request(req: &'a LlmRequest, provider: &'a ProviderConfig) -> Self {
376 let messages = req.messages.iter().map(CompatMessage::from).collect();
377 let tools = req.tools.iter().map(ResponseTool::from).collect();
378 let tool_choice = req
379 .tool_choice
380 .as_ref()
381 .map(|choice| json!({ "type": choice }));
382 let response_format = req.structured_output.as_ref().map(|schema| ResponseFormat {
383 r#type: "json_schema",
384 json_schema: json!({
385 "name": "structured_output",
386 "schema": schema.schema.clone()
387 }),
388 });
389 Self {
390 model: &provider.model,
391 messages,
392 tools,
393 tool_choice,
394 response_format,
395 }
396 }
397}
398
/// Chat-completions message: flat content string rather than content parts.
#[derive(Serialize)]
struct CompatMessage<'a> {
    role: &'a MessageRole,
    content: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<&'a String>,
}
406
407impl<'a> From<&'a ChatMessage> for CompatMessage<'a> {
408 fn from(msg: &'a ChatMessage) -> Self {
409 Self {
410 role: &msg.role,
411 content: &msg.content,
412 tool_call_id: msg.tool_call_id.as_ref(),
413 }
414 }
415}
416
/// Top level of a chat-completions reply; only `choices` is consumed.
#[derive(Deserialize)]
struct CompatResponse {
    choices: Vec<CompatChoice>,
}
421
422impl CompatResponse {
423 fn into_response(self) -> Result<LlmResponse, LlmError> {
424 let choice = self.choices.into_iter().next().ok_or_else(|| {
425 LlmError::Serialization("chat completion did not return choices".into())
426 })?;
427 let mut tool_calls = Vec::new();
428 if let Some(calls) = choice.message.tool_calls {
429 for call in calls {
430 let arguments = serde_json::from_str::<Value>(&call.function.arguments)
431 .map_err(|e| LlmError::Serialization(e.to_string()))?;
432 tool_calls.push(ToolCall {
433 id: call.id,
434 name: call.function.name,
435 arguments,
436 });
437 }
438 }
439 Ok(LlmResponse {
440 content: choice.message.content.unwrap_or_default(),
441 tool_calls,
442 })
443 }
444}
445
/// One choice of a chat-completions reply; only the message is kept.
#[derive(Deserialize)]
struct CompatChoice {
    message: CompatChoiceMessage,
}
450
/// Assistant message of a choice; `content` may be null when the model only
/// issued tool calls.
#[derive(Deserialize)]
struct CompatChoiceMessage {
    content: Option<String>,
    // `default` is redundant for Option (missing already deserializes to
    // None) but harmless.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}
457
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal stub implementation that echoes the last message back; lets
    // the LlmClient trait be exercised without any network traffic.
    struct EchoLlm;

    #[async_trait]
    impl LlmClient for EchoLlm {
        async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            // Fall back to an empty system message when the transcript is
            // empty so the stub never fails.
            let last = request.messages.last().cloned().unwrap_or(ChatMessage {
                role: MessageRole::System,
                content: String::new(),
                tool_call_id: None,
            });
            Ok(LlmResponse {
                content: last.content,
                tool_calls: vec![],
            })
        }
    }

    // The echo stub returns the content of the most recent message verbatim.
    #[tokio::test]
    async fn echo_llm_returns_last_message() {
        let llm = EchoLlm;
        let request = LlmRequest {
            messages: vec![
                ChatMessage {
                    role: MessageRole::System,
                    content: "rule".into(),
                    tool_call_id: None,
                },
                ChatMessage {
                    role: MessageRole::User,
                    content: "hello".into(),
                    tool_call_id: None,
                },
            ],
            structured_output: None,
            tools: vec![],
            tool_choice: None,
        };
        let response = llm.invoke(request).await.unwrap();
        assert_eq!(response.content, "hello");
        assert!(response.tool_calls.is_empty());
    }

    // Responses API parsing: tool-call arguments arrive as a JSON string and
    // must be decoded into a structured Value.
    #[test]
    fn openai_response_parses_tool_calls() {
        let response = OpenAiResponse {
            output: vec![ResponseOutput {
                content: vec![OutputContent::ToolCalls {
                    tool_calls: vec![OpenAiToolCall {
                        id: "tool_1".into(),
                        function: OpenAiFunctionCall {
                            name: "echo".into(),
                            arguments: "{\"message\": \"hello\"}".into(),
                        },
                    }],
                }],
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].name, "echo");
        assert_eq!(resp.tool_calls[0].arguments, json!({"message": "hello"}));
    }

    // Chat-completions parsing: plain text content passes through untouched.
    #[test]
    fn compat_response_parses_text() {
        let response = CompatResponse {
            choices: vec![CompatChoice {
                message: CompatChoiceMessage {
                    content: Some("hello world".into()),
                    tool_calls: None,
                },
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.content, "hello world");
    }

    // Chat-completions parsing: tool calls with a null content field.
    #[test]
    fn compat_response_parses_tool_calls() {
        let response = CompatResponse {
            choices: vec![CompatChoice {
                message: CompatChoiceMessage {
                    content: None,
                    tool_calls: Some(vec![OpenAiToolCall {
                        id: "tool_2".into(),
                        function: OpenAiFunctionCall {
                            name: "search".into(),
                            arguments: "{\"query\": \"rust\"}".into(),
                        },
                    }]),
                },
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].name, "search");
    }
}