1use crate::error::{LlmError, Result};
4use crate::tools::ToolCall;
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use super::message::LlmMessage;
10
11#[async_trait]
13pub trait LlmClient: Send + Sync {
14 async fn chat_completion(
16 &self,
17 messages: Vec<LlmMessage>,
18 tools: Option<Vec<ToolDefinition>>,
19 options: Option<ChatOptions>,
20 ) -> Result<LlmResponse>;
21
22 fn model_name(&self) -> &str;
24
25 fn provider_name(&self) -> &str;
27
28 fn supports_streaming(&self) -> bool {
30 false
31 }
32
33 async fn chat_completion_stream(
35 &self,
36 _messages: Vec<LlmMessage>,
37 _tools: Option<Vec<ToolDefinition>>,
38 _options: Option<ChatOptions>,
39 ) -> Result<Box<dyn futures::Stream<Item = Result<LlmStreamChunk>> + Send + Unpin + '_>> {
40 Err((LlmError::InvalidRequest {
41 message: "Streaming not supported by this client".to_string(),
42 })
43 .into())
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct LlmResponse {
50 pub message: LlmMessage,
52
53 pub usage: Option<Usage>,
55
56 pub model: String,
58
59 pub finish_reason: Option<FinishReason>,
61
62 pub metadata: Option<HashMap<String, serde_json::Value>>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct LlmStreamChunk {
69 pub delta: Option<String>,
71
72 pub tool_calls: Option<Vec<ToolCall>>,
74
75 pub finish_reason: Option<FinishReason>,
77
78 pub usage: Option<Usage>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84pub struct Usage {
85 pub prompt_tokens: u32,
87
88 pub completion_tokens: u32,
90
91 pub total_tokens: u32,
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
97#[serde(rename_all = "snake_case")]
98pub enum FinishReason {
99 Stop,
101
102 Length,
104
105 ToolCalls,
107
108 ContentFilter,
110
111 Other(String),
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct ToolDefinition {
118 #[serde(rename = "type")]
120 pub tool_type: String,
121
122 pub function: FunctionDefinition,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct FunctionDefinition {
129 pub name: String,
131
132 pub description: String,
134
135 pub parameters: serde_json::Value,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct ChatOptions {
142 pub max_tokens: Option<u32>,
144
145 pub temperature: Option<f32>,
147
148 pub top_p: Option<f32>,
150
151 pub top_k: Option<u32>,
153
154 pub stop: Option<Vec<String>>,
156
157 pub stream: Option<bool>,
159
160 pub tool_choice: Option<ToolChoice>,
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166#[serde(untagged)]
167pub enum ToolChoice {
168 Auto,
170
171 None,
173
174 Required { name: String },
176}
177
178impl Default for ChatOptions {
179 fn default() -> Self {
180 Self {
181 max_tokens: Some(8192),
182 temperature: Some(0.7),
183 top_p: Some(1.0),
184 top_k: None,
185 stop: None,
186 stream: Some(false),
187 tool_choice: Some(ToolChoice::Auto),
188 }
189 }
190}