1use crate::{LlmClientError, Message, Provider, Result};
4use serde::Serialize;
5
/// A fully assembled, provider-agnostic HTTP request.
///
/// Produced by [`RequestBuilder::build`]; contains everything an HTTP
/// client needs to execute the call (no client coupling here).
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// HTTP method; currently always `"POST"` (see `build`).
    pub method: &'static str,
    /// Fully qualified endpoint URL taken from `Provider::endpoint`.
    pub url: String,
    /// Header name/value pairs: content type, auth, and provider extras.
    pub headers: Vec<(String, String)>,
    /// JSON-encoded request body in the provider's wire format.
    pub body: String,
}
21
/// Fluent builder for chat-completion requests.
///
/// Configure via the consuming setter methods, then call `build()` to
/// obtain an [`HttpRequest`] in the wire format of the chosen provider.
#[derive(Debug, Clone)]
pub struct RequestBuilder {
    provider: Provider,
    model: Option<String>,      // required at build time
    messages: Vec<Message>,     // must be non-empty at build time
    api_key: Option<String>,    // required at build time
    temperature: Option<f32>,   // setter clamps to [0.0, 2.0]
    max_tokens: Option<u32>,
    stream: bool,               // defaults to false
    top_p: Option<f32>,         // setter clamps to [0.0, 1.0]
    stop: Option<Vec<String>>,
}
35
36impl RequestBuilder {
37 pub fn new(provider: Provider) -> Self {
39 Self {
40 provider,
41 model: None,
42 messages: Vec::new(),
43 api_key: None,
44 temperature: None,
45 max_tokens: None,
46 stream: false,
47 top_p: None,
48 stop: None,
49 }
50 }
51
52 pub fn model(mut self, model: impl Into<String>) -> Self {
54 self.model = Some(model.into());
55 self
56 }
57
58 pub fn messages(mut self, messages: &[Message]) -> Self {
60 self.messages = messages.to_vec();
61 self
62 }
63
64 pub fn add_message(mut self, message: Message) -> Self {
66 self.messages.push(message);
67 self
68 }
69
70 pub fn api_key(mut self, key: impl Into<String>) -> Self {
72 self.api_key = Some(key.into());
73 self
74 }
75
76 pub fn temperature(mut self, temp: f32) -> Self {
78 self.temperature = Some(temp.clamp(0.0, 2.0));
79 self
80 }
81
82 pub fn max_tokens(mut self, tokens: u32) -> Self {
84 self.max_tokens = Some(tokens);
85 self
86 }
87
88 pub fn stream(mut self, enable: bool) -> Self {
90 self.stream = enable;
91 self
92 }
93
94 pub fn top_p(mut self, p: f32) -> Self {
96 self.top_p = Some(p.clamp(0.0, 1.0));
97 self
98 }
99
100 pub fn stop(mut self, sequences: Vec<String>) -> Self {
102 self.stop = Some(sequences);
103 self
104 }
105
106 pub fn build(&self) -> Result<HttpRequest> {
108 let model = self
109 .model
110 .as_ref()
111 .ok_or_else(|| LlmClientError::missing("model"))?;
112 let api_key = self
113 .api_key
114 .as_ref()
115 .ok_or_else(|| LlmClientError::missing("api_key"))?;
116
117 if self.messages.is_empty() {
118 return Err(LlmClientError::missing("messages"));
119 }
120
121 let url = self.provider.endpoint().to_string();
122 let headers = self.build_headers(api_key);
123 let body = self.build_body(model)?;
124
125 Ok(HttpRequest {
126 method: "POST",
127 url,
128 headers,
129 body,
130 })
131 }
132
133 fn build_headers(&self, api_key: &str) -> Vec<(String, String)> {
134 let mut headers = vec![
135 ("Content-Type".to_string(), "application/json".to_string()),
136 (
137 self.provider.auth_header().to_string(),
138 self.provider.format_auth(api_key),
139 ),
140 ];
141
142 for (key, value) in self.provider.extra_headers() {
143 headers.push((key.to_string(), value.to_string()));
144 }
145
146 headers
147 }
148
149 fn build_body(&self, model: &str) -> Result<String> {
150 match self.provider {
151 Provider::OpenAI | Provider::OpenRouter => self.build_openai_body(model),
152 Provider::Anthropic => self.build_anthropic_body(model),
153 }
154 }
155
156 fn build_openai_body(&self, model: &str) -> Result<String> {
157 #[derive(Serialize)]
158 struct OpenAIRequest<'a> {
159 model: &'a str,
160 messages: &'a [OpenAIMessage<'a>],
161 #[serde(skip_serializing_if = "Option::is_none")]
162 temperature: Option<f32>,
163 #[serde(skip_serializing_if = "Option::is_none")]
164 max_tokens: Option<u32>,
165 stream: bool,
166 #[serde(skip_serializing_if = "Option::is_none")]
167 top_p: Option<f32>,
168 #[serde(skip_serializing_if = "Option::is_none")]
169 stop: Option<&'a [String]>,
170 }
171
172 #[derive(Serialize)]
173 struct OpenAIMessage<'a> {
174 role: &'a str,
175 content: &'a str,
176 }
177
178 let messages: Vec<OpenAIMessage> = self
179 .messages
180 .iter()
181 .map(|m| OpenAIMessage {
182 role: m.role.as_str(),
183 content: &m.content,
184 })
185 .collect();
186
187 let request = OpenAIRequest {
188 model,
189 messages: &messages,
190 temperature: self.temperature,
191 max_tokens: self.max_tokens,
192 stream: self.stream,
193 top_p: self.top_p,
194 stop: self.stop.as_deref(),
195 };
196
197 Ok(serde_json::to_string(&request)?)
198 }
199
200 fn build_anthropic_body(&self, model: &str) -> Result<String> {
201 #[derive(Serialize)]
202 struct AnthropicRequest<'a> {
203 model: &'a str,
204 #[serde(skip_serializing_if = "Option::is_none")]
205 system: Option<&'a str>,
206 messages: Vec<AnthropicMessage<'a>>,
207 max_tokens: u32,
208 #[serde(skip_serializing_if = "Option::is_none")]
209 temperature: Option<f32>,
210 stream: bool,
211 #[serde(skip_serializing_if = "Option::is_none")]
212 top_p: Option<f32>,
213 #[serde(skip_serializing_if = "Option::is_none")]
214 stop_sequences: Option<&'a [String]>,
215 }
216
217 #[derive(Serialize)]
218 struct AnthropicMessage<'a> {
219 role: &'a str,
220 content: &'a str,
221 }
222
223 let system = self
225 .messages
226 .iter()
227 .find(|m| m.role == crate::Role::System)
228 .map(|m| m.content.as_str());
229
230 let messages: Vec<AnthropicMessage> = self
232 .messages
233 .iter()
234 .filter(|m| m.role != crate::Role::System)
235 .map(|m| AnthropicMessage {
236 role: if m.role == crate::Role::User {
237 "user"
238 } else {
239 "assistant"
240 },
241 content: &m.content,
242 })
243 .collect();
244
245 let request = AnthropicRequest {
246 model,
247 system,
248 messages,
249 max_tokens: self.max_tokens.unwrap_or(4096),
250 temperature: self.temperature,
251 stream: self.stream,
252 top_p: self.top_p,
253 stop_sequences: self.stop.as_deref(),
254 };
255
256 Ok(serde_json::to_string(&request)?)
257 }
258}
259
260#[cfg(test)]
261mod tests {
262 use super::*;
263
264 #[test]
265 fn test_openai_request() {
266 let request = RequestBuilder::new(Provider::OpenAI)
267 .model("gpt-4o-mini")
268 .api_key("sk-test")
269 .add_message(Message::system("You are helpful"))
270 .add_message(Message::user("Hello"))
271 .temperature(0.7)
272 .max_tokens(1024)
273 .build()
274 .unwrap();
275
276 assert_eq!(request.method, "POST");
277 assert!(request.url.contains("openai.com"));
278 assert!(request.body.contains("gpt-4o-mini"));
279 assert!(request.body.contains("Hello"));
280
281 let auth_header = request.headers.iter().find(|(k, _)| k == "Authorization");
283 assert!(auth_header.is_some());
284 assert!(auth_header.unwrap().1.starts_with("Bearer "));
285 }
286
287 #[test]
288 fn test_anthropic_request() {
289 let request = RequestBuilder::new(Provider::Anthropic)
290 .model("claude-3-sonnet-20240229")
291 .api_key("sk-ant-test")
292 .add_message(Message::system("You are helpful"))
293 .add_message(Message::user("Hello"))
294 .max_tokens(1024)
295 .build()
296 .unwrap();
297
298 assert!(request.url.contains("anthropic.com"));
299 assert!(request.body.contains("claude-3"));
300 assert!(request.body.contains(r#""system":"You are helpful"#));
302
303 let version_header = request
305 .headers
306 .iter()
307 .find(|(k, _)| k == "anthropic-version");
308 assert!(version_header.is_some());
309 }
310
311 #[test]
312 fn test_missing_model() {
313 let result = RequestBuilder::new(Provider::OpenAI)
314 .api_key("sk-test")
315 .add_message(Message::user("Hello"))
316 .build();
317
318 assert!(result.is_err());
319 assert!(result.unwrap_err().to_string().contains("model"));
320 }
321
322 #[test]
323 fn test_missing_messages() {
324 let result = RequestBuilder::new(Provider::OpenAI)
325 .model("gpt-4")
326 .api_key("sk-test")
327 .build();
328
329 assert!(result.is_err());
330 assert!(result.unwrap_err().to_string().contains("messages"));
331 }
332
333 #[test]
334 fn test_streaming() {
335 let request = RequestBuilder::new(Provider::OpenAI)
336 .model("gpt-4")
337 .api_key("sk-test")
338 .add_message(Message::user("Hello"))
339 .stream(true)
340 .build()
341 .unwrap();
342
343 assert!(request.body.contains(r#""stream":true"#));
344 }
345}