// hefa_core/llm/mod.rs

1use async_trait::async_trait;
2use reqwest::Client;
3#[cfg(feature = "schema")]
4use schemars::{JsonSchema, schema_for};
5use serde::{Deserialize, Serialize};
6use serde_json::{Value, json};
7use thiserror::Error;
8
9use crate::config::{ConfigError, ProviderConfig, ProviderKind, ProviderResolver};
10
/// Minimal representation of an LLM chat message.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who authored the message (system / user / assistant / tool).
    pub role: MessageRole,
    /// Plain-text body of the message.
    pub content: String,
    /// For tool-role messages: id of the tool call this message answers.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

/// Chat roles, serialized in lowercase to match OpenAI-style wire formats
/// (`"system"`, `"user"`, `"assistant"`, `"tool"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
28
29/// Optional structured output schema payload.
30#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
31pub struct StructuredOutput {
32    pub schema: serde_json::Value,
33}
34
35impl StructuredOutput {
36    pub fn new(schema: serde_json::Value) -> Self {
37        Self { schema }
38    }
39
40    #[cfg(feature = "schema")]
41    pub fn from_type<T: JsonSchema>() -> Self {
42        let root = schema_for!(T);
43        let value =
44            serde_json::to_value(root.schema).expect("schemars schema should serialize to JSON");
45        Self { schema: value }
46    }
47}
48
/// Tool schema provided to LLMs for function calling.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Function name the model references when it emits a tool call.
    pub name: String,
    /// Optional human-readable description shown to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// JSON Schema for the function's arguments object.
    pub json_schema: serde_json::Value,
}
57
/// Request envelope passed to LLM clients.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmRequest {
    /// Conversation history, oldest message first.
    pub messages: Vec<ChatMessage>,
    /// When set, asks the provider to emit JSON conforming to this schema.
    pub structured_output: Option<StructuredOutput>,
    /// Tools the model may call; defaults to empty when absent on deserialize.
    #[serde(default)]
    pub tools: Vec<ToolDefinition>,
    /// Raw tool-choice directive; payload builders below forward it to
    /// providers as `{"type": <choice>}`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
}
68
/// Result of invoking an LLM.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Assistant text; may be empty when the reply consists only of tool calls.
    pub content: String,
    /// Tool invocations requested by the model, if any.
    pub tool_calls: Vec<ToolCall>,
}

/// A single tool invocation requested by the model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments parsed from the provider's JSON-encoded string into a value.
    pub arguments: serde_json::Value,
}
82
/// Common error contract for LLM invocations.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Provider configuration was missing or invalid.
    #[error("configuration error: {0}")]
    Config(#[from] crate::config::ConfigError),
    /// Transport failure or non-success HTTP status (body text included).
    #[error("network error: {0}")]
    Network(String),
    /// Request/response (de)serialization failure.
    #[error("serialization error: {0}")]
    Serialization(String),
    /// Operation not supported by the selected provider.
    #[error("unsupported operation: {0}")]
    Unsupported(String),
}
95
/// Abstraction over LLM backends; implementors must be `Send + Sync` so a
/// client can be shared across async tasks.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// Sends a chat request and returns the model's reply.
    async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError>;
}
100
/// HTTP-based client dedicated to a specific provider configuration.
// NOTE(review): Rust naming convention would be `LlmClient`, but that name is
// taken by the trait above and this type is public — left as-is.
pub struct LLMClient {
    // Reqwest client (holds the connection pool).
    http: Client,
    // Resolved provider settings (kind, base URL, model, optional API key).
    provider: ProviderConfig,
}
106
107impl LLMClient {
108    pub fn new(provider: ProviderConfig) -> Self {
109        Self {
110            http: Client::new(),
111            provider,
112        }
113    }
114
115    pub fn with_client(http: Client, provider: ProviderConfig) -> Self {
116        Self { http, provider }
117    }
118
119    pub fn from_env(kind: ProviderKind, model: &str) -> Result<Self, ConfigError> {
120        let resolver = ProviderResolver::from_process();
121        let provider = resolver.resolve_with_kind(kind, model)?;
122        Ok(Self::new(provider))
123    }
124
125    async fn invoke_openai(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
126        let base = self.provider.base_url.trim_end_matches('/');
127        let url = format!("{}/responses", base);
128        let payload = OpenAiPayload::from_request(request, &self.provider);
129        let mut builder = self.http.post(url).json(&payload);
130        if let Some(key) = &self.provider.api_key {
131            builder = builder.bearer_auth(key);
132        } else {
133            return Err(LlmError::Config(crate::config::ConfigError::MissingEnv(
134                ProviderKind::OpenAi,
135            )));
136        }
137        let resp = builder
138            .send()
139            .await
140            .map_err(|e| LlmError::Network(e.to_string()))?;
141        let status = resp.status();
142        let bytes = resp
143            .bytes()
144            .await
145            .map_err(|e| LlmError::Network(e.to_string()))?;
146        if !status.is_success() {
147            return Err(LlmError::Network(format!(
148                "OpenAI request failed with {}: {}",
149                status,
150                String::from_utf8_lossy(&bytes)
151            )));
152        }
153        let parsed: OpenAiResponse =
154            serde_json::from_slice(&bytes).map_err(|e| LlmError::Serialization(e.to_string()))?;
155        Ok(parsed.into_response()?)
156    }
157
158    async fn invoke_compat(&self, request: &LlmRequest) -> Result<LlmResponse, LlmError> {
159        let base = self.provider.base_url.trim_end_matches('/');
160        let url = format!("{}/chat/completions", base);
161        let payload = CompatPayload::from_request(request, &self.provider);
162        let mut builder = self.http.post(url).json(&payload);
163        if let Some(key) = &self.provider.api_key {
164            builder = builder.bearer_auth(key);
165        }
166        let resp = builder
167            .send()
168            .await
169            .map_err(|e| LlmError::Network(e.to_string()))?;
170        let status = resp.status();
171        let bytes = resp
172            .bytes()
173            .await
174            .map_err(|e| LlmError::Network(e.to_string()))?;
175        if !status.is_success() {
176            return Err(LlmError::Network(format!(
177                "Compat request failed with {}: {}",
178                status,
179                String::from_utf8_lossy(&bytes)
180            )));
181        }
182        let parsed: CompatResponse =
183            serde_json::from_slice(&bytes).map_err(|e| LlmError::Serialization(e.to_string()))?;
184        Ok(parsed.into_response()?)
185    }
186}
187
#[async_trait]
impl LlmClient for LLMClient {
    /// Dispatches on the configured provider kind: OpenAI uses the Responses
    /// API path, Ollama and LM Studio use the OpenAI-compatible
    /// `/chat/completions` path.
    async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
        match self.provider.kind {
            ProviderKind::OpenAi => self.invoke_openai(&request).await,
            ProviderKind::Ollama | ProviderKind::LmStudio => self.invoke_compat(&request).await,
        }
    }
}
197
198#[derive(Serialize)]
199struct OpenAiPayload<'a> {
200    model: &'a str,
201    input: Vec<ResponseInput<'a>>,
202    #[serde(skip_serializing_if = "Option::is_none")]
203    response_format: Option<ResponseFormat>,
204    #[serde(skip_serializing_if = "Vec::is_empty", default)]
205    tools: Vec<ResponseTool<'a>>,
206    #[serde(skip_serializing_if = "Option::is_none")]
207    tool_choice: Option<Value>,
208}
209
210impl<'a> OpenAiPayload<'a> {
211    fn from_request(req: &'a LlmRequest, provider: &'a ProviderConfig) -> Self {
212        let input = req
213            .messages
214            .iter()
215            .map(ResponseInput::from)
216            .collect::<Vec<_>>();
217        let response_format = req.structured_output.as_ref().map(|schema| ResponseFormat {
218            r#type: "json_schema",
219            json_schema: json!({
220                "name": "structured_output",
221                "schema": schema.schema.clone()
222            }),
223        });
224        let tools = req.tools.iter().map(ResponseTool::from).collect();
225        let tool_choice = req
226            .tool_choice
227            .as_ref()
228            .map(|choice| json!({ "type": choice }));
229        Self {
230            model: &provider.model,
231            input,
232            response_format,
233            tools,
234            tool_choice,
235        }
236    }
237}
238
239#[derive(Serialize)]
240struct ResponseInput<'a> {
241    role: &'a MessageRole,
242    content: Vec<ResponseContent<'a>>,
243}
244
245impl<'a> From<&'a ChatMessage> for ResponseInput<'a> {
246    fn from(msg: &'a ChatMessage) -> Self {
247        let content = vec![ResponseContent {
248            r#type: "text",
249            text: &msg.content,
250        }];
251        Self {
252            role: &msg.role,
253            content,
254        }
255    }
256}
257
/// One typed content part of a Responses API input message.
#[derive(Serialize)]
struct ResponseContent<'a> {
    // Discriminator; always "text" in this module.
    r#type: &'static str,
    text: &'a str,
}

/// `response_format` payload requesting schema-constrained JSON output.
#[derive(Serialize)]
struct ResponseFormat {
    // Always "json_schema" in this module.
    r#type: &'static str,
    // `{ "name": ..., "schema": ... }` envelope around the JSON Schema.
    json_schema: Value,
}
269
/// Tool entry in the request payload (OpenAI function-calling shape).
#[derive(Serialize)]
struct ResponseTool<'a> {
    // Always "function" in this module.
    r#type: &'static str,
    function: ResponseFunction<'a>,
}

impl<'a> From<&'a ToolDefinition> for ResponseTool<'a> {
    /// Borrows a `ToolDefinition` into the wire representation without cloning.
    fn from(tool: &'a ToolDefinition) -> Self {
        Self {
            r#type: "function",
            function: ResponseFunction {
                name: &tool.name,
                description: tool.description.as_deref(),
                parameters: &tool.json_schema,
            },
        }
    }
}

/// Function descriptor nested inside a tool entry.
#[derive(Serialize)]
struct ResponseFunction<'a> {
    name: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<&'a str>,
    // JSON Schema for the function's arguments object.
    parameters: &'a serde_json::Value,
}
296
297#[derive(Deserialize)]
298struct OpenAiResponse {
299    output: Vec<ResponseOutput>,
300}
301
302impl OpenAiResponse {
303    fn into_response(self) -> Result<LlmResponse, LlmError> {
304        let mut text = String::new();
305        let mut tool_calls = Vec::new();
306        for item in self.output {
307            for content in item.content {
308                match content {
309                    OutputContent::OutputText { text: chunk } => {
310                        text.push_str(&chunk.text);
311                    }
312                    OutputContent::ToolCalls { tool_calls: calls } => {
313                        for call in calls {
314                            let arguments = serde_json::from_str::<Value>(&call.function.arguments)
315                                .map_err(|e| LlmError::Serialization(e.to_string()))?;
316                            tool_calls.push(ToolCall {
317                                id: call.id,
318                                name: call.function.name,
319                                arguments,
320                            });
321                        }
322                    }
323                }
324            }
325        }
326        Ok(LlmResponse {
327            content: text,
328            tool_calls,
329        })
330    }
331}
332
/// One output item of a Responses API reply.
#[derive(Deserialize)]
struct ResponseOutput {
    content: Vec<OutputContent>,
}

/// Content variants inside an output item, discriminated by the `type` field
/// (snake_case: `"output_text"` / `"tool_calls"`).
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OutputContent {
    /// A chunk of assistant text.
    OutputText { text: TextChunk },
    /// A batch of tool calls.
    ToolCalls { tool_calls: Vec<OpenAiToolCall> },
}

/// Wrapper around a text payload.
#[derive(Deserialize)]
struct TextChunk {
    text: String,
}

/// Tool call as received from the provider.
#[derive(Deserialize)]
struct OpenAiToolCall {
    id: String,
    function: OpenAiFunctionCall,
}

/// Function-call details; `arguments` is a JSON-encoded string, not an object,
/// and is parsed by the `into_response` implementations.
#[derive(Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}
361
362#[derive(Serialize)]
363struct CompatPayload<'a> {
364    model: &'a str,
365    messages: Vec<CompatMessage<'a>>,
366    #[serde(skip_serializing_if = "Vec::is_empty", default)]
367    tools: Vec<ResponseTool<'a>>,
368    #[serde(skip_serializing_if = "Option::is_none")]
369    tool_choice: Option<Value>,
370    #[serde(skip_serializing_if = "Option::is_none")]
371    response_format: Option<ResponseFormat>,
372}
373
374impl<'a> CompatPayload<'a> {
375    fn from_request(req: &'a LlmRequest, provider: &'a ProviderConfig) -> Self {
376        let messages = req.messages.iter().map(CompatMessage::from).collect();
377        let tools = req.tools.iter().map(ResponseTool::from).collect();
378        let tool_choice = req
379            .tool_choice
380            .as_ref()
381            .map(|choice| json!({ "type": choice }));
382        let response_format = req.structured_output.as_ref().map(|schema| ResponseFormat {
383            r#type: "json_schema",
384            json_schema: json!({
385                "name": "structured_output",
386                "schema": schema.schema.clone()
387            }),
388        });
389        Self {
390            model: &provider.model,
391            messages,
392            tools,
393            tool_choice,
394            response_format,
395        }
396    }
397}
398
399#[derive(Serialize)]
400struct CompatMessage<'a> {
401    role: &'a MessageRole,
402    content: &'a str,
403    #[serde(skip_serializing_if = "Option::is_none")]
404    tool_call_id: Option<&'a String>,
405}
406
407impl<'a> From<&'a ChatMessage> for CompatMessage<'a> {
408    fn from(msg: &'a ChatMessage) -> Self {
409        Self {
410            role: &msg.role,
411            content: &msg.content,
412            tool_call_id: msg.tool_call_id.as_ref(),
413        }
414    }
415}
416
417#[derive(Deserialize)]
418struct CompatResponse {
419    choices: Vec<CompatChoice>,
420}
421
422impl CompatResponse {
423    fn into_response(self) -> Result<LlmResponse, LlmError> {
424        let choice = self.choices.into_iter().next().ok_or_else(|| {
425            LlmError::Serialization("chat completion did not return choices".into())
426        })?;
427        let mut tool_calls = Vec::new();
428        if let Some(calls) = choice.message.tool_calls {
429            for call in calls {
430                let arguments = serde_json::from_str::<Value>(&call.function.arguments)
431                    .map_err(|e| LlmError::Serialization(e.to_string()))?;
432                tool_calls.push(ToolCall {
433                    id: call.id,
434                    name: call.function.name,
435                    arguments,
436                });
437            }
438        }
439        Ok(LlmResponse {
440            content: choice.message.content.unwrap_or_default(),
441            tool_calls,
442        })
443    }
444}
445
/// One completion choice; only the first is consumed by `into_response`.
#[derive(Deserialize)]
struct CompatChoice {
    message: CompatChoiceMessage,
}

/// Assistant message within a choice.
#[derive(Deserialize)]
struct CompatChoiceMessage {
    // Text content; absent when the reply consists only of tool calls.
    content: Option<String>,
    // Tool calls requested by the model. `default` is redundant for an
    // `Option` field (a missing field already deserializes to `None`) but
    // harmless.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCall>>,
}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that echoes the last request message back as the response.
    struct EchoLlm;

    #[async_trait]
    impl LlmClient for EchoLlm {
        async fn invoke(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
            // Fall back to an empty system message when the request is empty.
            let last = request.messages.last().cloned().unwrap_or(ChatMessage {
                role: MessageRole::System,
                content: String::new(),
                tool_call_id: None,
            });
            Ok(LlmResponse {
                content: last.content,
                tool_calls: vec![],
            })
        }
    }

    // The trait object contract: invoke returns the last message's content
    // and no tool calls for the echo double.
    #[tokio::test]
    async fn echo_llm_returns_last_message() {
        let llm = EchoLlm;
        let request = LlmRequest {
            messages: vec![
                ChatMessage {
                    role: MessageRole::System,
                    content: "rule".into(),
                    tool_call_id: None,
                },
                ChatMessage {
                    role: MessageRole::User,
                    content: "hello".into(),
                    tool_call_id: None,
                },
            ],
            structured_output: None,
            tools: vec![],
            tool_choice: None,
        };
        let response = llm.invoke(request).await.unwrap();
        assert_eq!(response.content, "hello");
        assert!(response.tool_calls.is_empty());
    }

    // Responses API parsing: JSON-string arguments are decoded into a Value.
    #[test]
    fn openai_response_parses_tool_calls() {
        let response = OpenAiResponse {
            output: vec![ResponseOutput {
                content: vec![OutputContent::ToolCalls {
                    tool_calls: vec![OpenAiToolCall {
                        id: "tool_1".into(),
                        function: OpenAiFunctionCall {
                            name: "echo".into(),
                            arguments: "{\"message\": \"hello\"}".into(),
                        },
                    }],
                }],
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].name, "echo");
        assert_eq!(resp.tool_calls[0].arguments, json!({"message": "hello"}));
    }

    // Chat-completions parsing: plain text content passes through.
    #[test]
    fn compat_response_parses_text() {
        let response = CompatResponse {
            choices: vec![CompatChoice {
                message: CompatChoiceMessage {
                    content: Some("hello world".into()),
                    tool_calls: None,
                },
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.content, "hello world");
    }

    // Chat-completions parsing: tool calls decode with None content.
    #[test]
    fn compat_response_parses_tool_calls() {
        let response = CompatResponse {
            choices: vec![CompatChoice {
                message: CompatChoiceMessage {
                    content: None,
                    tool_calls: Some(vec![OpenAiToolCall {
                        id: "tool_2".into(),
                        function: OpenAiFunctionCall {
                            name: "search".into(),
                            arguments: "{\"query\": \"rust\"}".into(),
                        },
                    }]),
                },
            }],
        };
        let resp = response.into_response().expect("parse");
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0].name, "search");
    }
}