1use async_trait::async_trait;
4use futures::stream::{self, Stream};
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::collections::HashMap;
7use std::pin::Pin;
8use thiserror::Error;
9
10#[derive(Error, Debug)]
12pub enum ProviderError {
13 #[error("HTTP request failed: {0}")]
14 HttpError(#[from] reqwest::Error),
15
16 #[error("JSON parsing failed: {0}")]
17 JsonError(#[from] serde_json::Error),
18
19 #[error("Invalid response: {0}")]
20 InvalidResponse(String),
21
22 #[error("API error: {0}")]
23 ApiError(String),
24
25 #[error("Configuration error: {0}")]
26 ConfigError(String),
27}
28
29pub type ProviderResult<T> = Result<T, ProviderError>;
30
31pub type ProviderEventStream = Pin<Box<dyn Stream<Item = ProviderResult<LLMStreamEvent>> + Send>>;
32
33#[derive(Debug, Clone)]
35pub struct ToolCallRequest {
36 pub id: String,
37 pub call_type: String,
38 pub name: String,
39 pub arguments: HashMap<String, serde_json::Value>,
40}
41
42impl Serialize for ToolCallRequest {
43 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44 where
45 S: Serializer,
46 {
47 use serde::ser::Error as _;
48 use serde::ser::SerializeStruct;
49
50 #[derive(Serialize)]
51 struct Function<'a> {
52 name: &'a str,
53 arguments: String,
54 }
55
56 let arguments = serde_json::to_string(&self.arguments).map_err(|e| {
57 S::Error::custom(format!(
58 "failed to serialize tool call arguments for {}: {}",
59 self.name, e
60 ))
61 })?;
62
63 let mut state = serializer.serialize_struct("ToolCallRequest", 3)?;
64 state.serialize_field("id", &self.id)?;
65 state.serialize_field("type", &self.call_type)?;
66 state.serialize_field(
67 "function",
68 &Function {
69 name: &self.name,
70 arguments,
71 },
72 )?;
73 state.end()
74 }
75}
76
77impl<'de> Deserialize<'de> for ToolCallRequest {
78 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
79 where
80 D: Deserializer<'de>,
81 {
82 #[derive(Deserialize)]
83 struct Function {
84 name: String,
85 arguments: serde_json::Value,
86 }
87
88 #[derive(Deserialize)]
89 struct Helper {
90 id: String,
91 #[serde(rename = "type")]
92 call_type: String,
93 #[serde(default)]
94 function: Option<Function>,
95 #[serde(default)]
96 name: Option<String>,
97 #[serde(default)]
98 arguments: Option<serde_json::Value>,
99 }
100
101 fn normalize_arguments(value: serde_json::Value) -> HashMap<String, serde_json::Value> {
102 match value {
103 serde_json::Value::String(raw) => serde_json::from_str::<
104 HashMap<String, serde_json::Value>,
105 >(&raw)
106 .unwrap_or_else(|_| {
107 let mut map = HashMap::new();
108 map.insert("raw".to_string(), serde_json::Value::String(raw));
109 map
110 }),
111 serde_json::Value::Object(map) => map.into_iter().collect(),
112 _ => HashMap::new(),
113 }
114 }
115
116 let helper = Helper::deserialize(deserializer)?;
117 if let Some(function) = helper.function {
118 return Ok(Self {
119 id: helper.id,
120 call_type: helper.call_type,
121 name: function.name,
122 arguments: normalize_arguments(function.arguments),
123 });
124 }
125
126 let name = helper
127 .name
128 .ok_or_else(|| serde::de::Error::missing_field("function or name"))?;
129 let arguments = helper
130 .arguments
131 .map(normalize_arguments)
132 .unwrap_or_default();
133
134 Ok(Self {
135 id: helper.id,
136 call_type: helper.call_type,
137 name,
138 arguments,
139 })
140 }
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct LLMResponse {
146 pub content: Option<String>,
147 #[serde(default)]
148 pub tool_calls: Vec<ToolCallRequest>,
149 #[serde(default = "default_finish_reason")]
150 pub finish_reason: String,
151 #[serde(default)]
152 pub usage: HashMap<String, i64>,
153 #[serde(default)]
154 pub reasoning_content: Option<String>,
155}
156
/// Serde default for `LLMResponse::finish_reason`: always the string `"stop"`.
fn default_finish_reason() -> String {
    String::from("stop")
}
160
161impl LLMResponse {
162 pub fn has_tool_calls(&self) -> bool {
164 !self.tool_calls.is_empty()
165 }
166}
167
168#[derive(Debug, Clone, Serialize, Deserialize)]
170pub enum LLMStreamEvent {
171 TextDelta(String),
173 ReasoningDelta(String),
175 ToolCallDelta {
177 index: usize,
178 id: Option<String>,
179 name: Option<String>,
180 arguments_delta: Option<String>,
181 },
182 Completed(LLMResponse),
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct Message {
189 pub role: String,
190 pub content: String,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 pub name: Option<String>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 pub tool_call_id: Option<String>,
195 #[serde(skip_serializing_if = "Option::is_none")]
196 pub tool_calls: Option<Vec<ToolCallRequest>>,
197 #[serde(skip_serializing_if = "Option::is_none")]
198 pub reasoning_content: Option<String>,
199 #[serde(skip_serializing_if = "Option::is_none")]
200 pub thinking_blocks: Option<Vec<serde_json::Value>>,
201}
202
203impl Message {
204 pub fn user(content: impl Into<String>) -> Self {
206 Self {
207 role: "user".to_string(),
208 content: content.into(),
209 name: None,
210 tool_call_id: None,
211 tool_calls: None,
212 reasoning_content: None,
213 thinking_blocks: None,
214 }
215 }
216
217 pub fn system(content: impl Into<String>) -> Self {
219 Self {
220 role: "system".to_string(),
221 content: content.into(),
222 name: None,
223 tool_call_id: None,
224 tool_calls: None,
225 reasoning_content: None,
226 thinking_blocks: None,
227 }
228 }
229
230 pub fn assistant(content: impl Into<String>) -> Self {
232 Self {
233 role: "assistant".to_string(),
234 content: content.into(),
235 name: None,
236 tool_call_id: None,
237 tool_calls: None,
238 reasoning_content: None,
239 thinking_blocks: None,
240 }
241 }
242
243 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
245 Self {
246 role: "tool".to_string(),
247 content: content.into(),
248 name: None,
249 tool_call_id: Some(tool_call_id.into()),
250 tool_calls: None,
251 reasoning_content: None,
252 thinking_blocks: None,
253 }
254 }
255}
256
257#[async_trait]
259pub trait LLMProvider: Send + Sync {
260 async fn chat(
262 &self,
263 messages: Vec<Message>,
264 tools: Option<Vec<serde_json::Value>>,
265 model: Option<String>,
266 max_tokens: i32,
267 temperature: f64,
268 ) -> ProviderResult<LLMResponse>;
269
270 async fn chat_stream(
274 &self,
275 messages: Vec<Message>,
276 tools: Option<Vec<serde_json::Value>>,
277 model: Option<String>,
278 max_tokens: i32,
279 temperature: f64,
280 ) -> ProviderResult<ProviderEventStream> {
281 let response = self
282 .chat(messages, tools, model, max_tokens, temperature)
283 .await?;
284
285 let mut events = Vec::new();
286 if let Some(content) = response.content.clone() {
287 if !content.is_empty() {
288 events.push(Ok(LLMStreamEvent::TextDelta(content)));
289 }
290 }
291 events.push(Ok(LLMStreamEvent::Completed(response)));
292
293 Ok(Box::pin(stream::iter(events)))
294 }
295
296 fn get_default_model(&self) -> String;
298}