1use log::{debug, error};
11use crate::error::ApiError;
12use crate::request::{Message, RequestBody};
13use reqwest::Client;
14use serde_json::{json, Number};
15use crate::response::{OpenAIResponse, ResponseMessage};
16use crate::tool::Tool;
17
// --- Anthropic Messages API wiring -------------------------------------------------

/// URL that `AnthropicClient` posts requests to.
const API_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` header.
const API_VERSION: &str = "2023-06-01";
/// Model used for Anthropic requests when the builder sets none.
const DEFAULT_ANTHROPIC_MODEL: &str = "claude-3-haiku-20240307";

// --- OpenAI ------------------------------------------------------------------------

/// Model used for OpenAI requests when the builder sets none.
const DEFAULT_OPENAI_MODEL: &str = "gpt-4o";

// --- Sampling defaults shared by every backend -------------------------------------

/// `max_tokens` used when the builder sets none.
const DEFAULT_MAX_TOKENS: u32 = 100;
/// Sampling temperature used when the builder sets none.
const DEFAULT_TEMP: f64 = 0.0;
25
/// Selects which LLM backend a client talks to.
///
/// The variant decides both the concrete HTTP client `LlmClient::new` builds and the
/// JSON layout `RequestBuilder::render_request` produces. It is a plain fieldless
/// selector, so it is `Copy` and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientLlm {
    /// Anthropic Messages API (served by `AnthropicClient`).
    Anthropic,
    /// OpenAI Chat Completions API (served by `OpenAIClient`).
    OpenAI,
}
32
/// Backend abstraction implemented by every provider client.
///
/// `RequestBuilder` calls `client_type` to choose the JSON layout it renders, and
/// `send_message` to perform the call. The `Send + Sync` supertraits require every
/// implementor to be safe to share across threads.
#[async_trait::async_trait]
pub trait LlmClientTrait: Send + Sync {
    /// Sends an already-rendered JSON request body and returns the provider's parsed reply.
    async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError>;
    /// Identifies the backend; `RequestBuilder` renders requests differently per variant.
    fn client_type(&self) -> ClientLlm;
}
38
39pub struct RequestBuilder<'a> {
45 client: &'a (dyn LlmClientTrait + Send + Sync),
46 model: Option<String>,
47 messages: Option<Vec<Message>>,
48 max_tokens: Option<u32>,
49 temperature: Option<f64>,
50 system_prompt: Option<String>,
51 tools: Option<Vec<Tool>>
52}
53
54impl<'a> RequestBuilder<'a> {
55 pub fn new(client: &'a (dyn LlmClientTrait + Send + Sync)) -> Self {
56 RequestBuilder {
57 client,
58 model: None,
59 messages: None,
60 max_tokens: None,
61 temperature: None,
62 system_prompt: None,
63 tools: None,
64 }
65 }
66
67 pub fn add_tool(mut self, tool: Tool) -> Self {
68 if let Some(mut tools) = self.tools {
69 tools.push(tool);
70 self.tools = Some(tools);
71 } else {
72 self.tools = Some(vec![tool]);
73 }
74 self
75 }
76
77 pub fn model(mut self, model: &str) -> Self {
79 self.model = Some(model.to_string());
80 self
81 }
82
83 pub fn user_message(mut self, message: &str) -> Self {
85 if let Some(mut messages) = self.messages {
86 messages.push(Message {
87 role: "user".to_string(),
88 content: message.to_string(),
89 });
90 self.messages = Some(messages);
91 } else {
92 self.messages = Some(vec![Message {
93 role: "user".to_string(),
94 content: message.to_string(),
95 }]);
96 }
97 self
98 }
99
100 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
102 self.max_tokens = Some(max_tokens);
103 self
104 }
105
106 pub fn temperature(mut self, temperature: f64) -> Self {
108 self.temperature = Some(temperature);
109 self
110 }
111
112 pub fn system_prompt(mut self, system_prompt: &str) -> Self {
114 self.system_prompt = Some(system_prompt.to_string());
115 self
116 }
117
118 pub fn render_request(&self) -> Result<serde_json::Value, ApiError> {
119 let model = self.model.clone().unwrap_or_else(|| {
120 match self.client.client_type() {
121 ClientLlm::Anthropic => DEFAULT_ANTHROPIC_MODEL.to_string(),
122 ClientLlm::OpenAI => DEFAULT_OPENAI_MODEL.to_string(),
123 }
125 });
126 let messages = self.messages.clone().ok_or(ApiError::MissingMessages)?;
127 let max_tokens = self.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
128 let temperature = self.temperature.unwrap_or(DEFAULT_TEMP);
129 let temperature_number = Number::from_f64(temperature)
130 .ok_or_else(|| ApiError::InvalidUsage(format!("Invalid temperature value: {}", temperature)))?;
131 let system_prompt = self.system_prompt.clone().unwrap_or_default();
132
133 match self.client.client_type() {
134 ClientLlm::Anthropic => {
135 let mut request = json!({
136 "model": model,
137 "messages": messages,
138 "max_tokens": max_tokens,
139 "temperature": temperature_number,
140 "system": system_prompt,
141 });
142
143 if let Some(tools) = &self.tools {
144 let anthropic_tools: Vec<serde_json::Value> = tools.iter()
145 .map(|tool| tool.to_anthropic_format())
146 .collect();
147 request["tools"] = json!(anthropic_tools);
148 }
149
150 Ok(request)
151 },
152 ClientLlm::OpenAI => {
153 let mut request = json!({
154 "model": model,
155 "messages": messages,
156 "max_tokens": max_tokens,
157 "temperature": temperature_number,
158 });
159
160 if !system_prompt.is_empty() {
161 request["messages"].as_array_mut().unwrap().push(json!({
162 "role": "system",
163 "content": system_prompt
164 }));
165 }
166
167 if let Some(tools) = &self.tools {
168 let openai_tools: Vec<serde_json::Value> = tools.iter()
169 .map(|tool| tool.to_openai_format())
170 .collect();
171 request["tools"] = json!(openai_tools);
172 }
173
174 Ok(request)
175 },
176 }
177 }
178
179
180 pub async fn send(self) -> Result<ResponseMessage, ApiError> {
181 let request_body = self.render_request()?;
182 self.client.send_message(request_body).await
183 }
184}
185
186pub struct AnthropicClient {
188 api_key: String,
189 client: Client,
190}
191
192impl AnthropicClient {
193 pub fn new(api_key: String) -> Self {
194 let client = Client::new();
195 AnthropicClient { api_key, client }
196 }
197}
198
199#[async_trait::async_trait]
200impl LlmClientTrait for AnthropicClient {
201 async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
202 let response = self.client
203 .post(API_ENDPOINT)
204 .header("x-api-key", &self.api_key)
205 .header("anthropic-version", API_VERSION)
206 .header("content-type", "application/json")
207 .json(&request_body)
208 .send()
209 .await?;
210 let resp_status = response.status();
211 let resp_text = response.text().await.unwrap_or("".into());
212 if resp_status.is_client_error() {
213 error!("Client error [{}]: {}", resp_status, resp_text);
214 return Err(ApiError::ClientError(
215 format!("Status: {} - Error: {}", resp_status, resp_text)));
216 } else if resp_status.is_server_error() {
217 error!("Server error [{}]: {}", resp_status, resp_text);
218 return Err(ApiError::ServerError(
219 format!("Status: {} - Error: {}", resp_status, resp_text)));
220 }
221 debug!("LLM call response: status[{}]\n{}", resp_status, resp_text);
222 let response_message = serde_json::from_str(&resp_text)?;
223
224 Ok(response_message)
225 }
226
227 fn client_type(&self) -> ClientLlm {
228 ClientLlm::Anthropic
229 }
230}
231
232pub struct OpenAIClient {
234 api_key: String,
235 client: Client,
236}
237
238impl OpenAIClient {
239 pub fn new(api_key: String) -> Self {
240 let client = Client::new();
241 OpenAIClient { api_key, client }
242 }
243}
244
245#[async_trait::async_trait]
246impl LlmClientTrait for OpenAIClient {
247 async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
248 let response = self.client
249 .post("https://api.openai.com/v1/chat/completions")
250 .header("Authorization", format!("Bearer {}", self.api_key))
251 .header("Content-Type", "application/json")
252 .json(&request_body)
253 .send()
254 .await?;
255
256 let resp_status = response.status();
257 let resp_text = response.text().await.unwrap_or("".into());
258 if resp_status.is_client_error() {
259 return Err(ApiError::ClientError(format!("Status: {} - Error: {}", resp_status, resp_text)));
260 } else if resp_status.is_server_error() {
261 return Err(ApiError::ServerError(format!("Status: {} - Error: {}", resp_status, resp_text)));
262 }
263
264 let openai_response: OpenAIResponse = serde_json::from_str(&resp_text)?;
265 Ok(ResponseMessage::OpenAI(openai_response))
266 }
267
268 fn client_type(&self) -> ClientLlm {
269 ClientLlm::OpenAI
270 }
271}
272
273pub struct LlmClient {
279 client: Box<dyn LlmClientTrait + Send + Sync>,
280}
281
282impl LlmClient {
283 pub fn new(client_type: ClientLlm, api_key: String) -> Self {
285 let client: Box<dyn LlmClientTrait + Send + Sync> = match client_type {
286 ClientLlm::Anthropic => Box::new(AnthropicClient::new(api_key)),
287 ClientLlm::OpenAI => Box::new(OpenAIClient::new(api_key)),
288 };
289 LlmClient { client }
290 }
291
292 pub fn request(&mut self) -> RequestBuilder {
294 RequestBuilder::new(self.client.as_ref())
295 }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::Tool;

    /// Stand-in backend for rendering tests. It reports a fixed `ClientLlm` and is never
    /// asked to send anything.
    struct MockClient {
        client_type: ClientLlm,
    }

    #[async_trait::async_trait]
    impl LlmClientTrait for MockClient {
        async fn send_message(&self, _request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
            unimplemented!("MockClient is only used to render requests")
        }

        fn client_type(&self) -> ClientLlm {
            self.client_type.clone()
        }
    }

    /// Weather-lookup tool shared by the tool-rendering tests.
    fn get_weather_tool() -> Tool {
        Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter("location", "string", "The city and state, e.g. San Francisco, CA", true)
            .add_enum_parameter("unit", "The unit of temperature, either 'celsius' or 'fahrenheit'", false, vec!["celsius".to_string(), "fahrenheit".to_string()])
            .build()
            .expect("Failed to build tool")
    }

    #[test]
    fn test_anthropic_default_request() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello, Claude!");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], DEFAULT_ANTHROPIC_MODEL);
        assert_eq!(request["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(request["temperature"], DEFAULT_TEMP);
        assert_eq!(request["system"], "");
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello, Claude!");
    }

    #[test]
    fn test_openai_default_request() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello, GPT!");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], DEFAULT_OPENAI_MODEL);
        assert_eq!(request["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(request["temperature"], DEFAULT_TEMP);
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello, GPT!");
        // OpenAI has no top-level `system` field, and an empty system prompt adds no message.
        assert!(request.get("system").is_none());
        assert_eq!(request["messages"].as_array().unwrap().len(), 1);
    }

    #[test]
    fn test_custom_model_and_parameters() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .model("custom-model")
            .max_tokens(500)
            .temperature(0.8)
            .system_prompt("You are a helpful assistant.")
            .user_message("Tell me a joke.");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], "custom-model");
        assert_eq!(request["max_tokens"], 500);
        assert_eq!(request["temperature"], json!(0.8));
        assert_eq!(request["system"], "You are a helpful assistant.");
        assert_eq!(request["messages"][0]["content"], "Tell me a joke.");
    }

    #[test]
    fn test_multiple_messages() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello!")
            .user_message("How are you?");

        let request = builder.render_request().unwrap();

        assert_eq!(request["messages"].as_array().unwrap().len(), 2);
        assert_eq!(request["messages"][0]["content"], "Hello!");
        assert_eq!(request["messages"][1]["content"], "How are you?");
    }

    #[test]
    fn test_missing_messages() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client);

        let result = builder.render_request();

        assert!(matches!(result, Err(ApiError::MissingMessages)));
    }

    #[test]
    fn test_openai_system_prompt() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .system_prompt("You are a helpful assistant.")
            .user_message("Hello!");

        let request = builder.render_request().unwrap();

        // NOTE(review): pins the current behaviour, in which the system message is
        // appended after the user messages.
        assert_eq!(request["messages"].as_array().unwrap().len(), 2);
        assert_eq!(request["messages"][1]["role"], "system");
        assert_eq!(request["messages"][1]["content"], "You are a helpful assistant.");
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello!");
    }

    #[test]
    fn test_default_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(DEFAULT_TEMP));
    }

    #[test]
    fn test_custom_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let custom_temp = 0.7;
        let builder = RequestBuilder::new(&client)
            .temperature(custom_temp)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(custom_temp));
    }

    #[test]
    fn test_temperature_precision() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let precise_temp = 0.12345;
        let builder = RequestBuilder::new(&client)
            .temperature(precise_temp)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(precise_temp));
    }

    #[test]
    fn test_invalid_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };

        for &invalid_temp in &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN] {
            let result = RequestBuilder::new(&client)
                .temperature(invalid_temp)
                .user_message("Test message")
                .render_request();
            assert!(matches!(result, Err(ApiError::InvalidUsage(_))));
        }
    }

    #[test]
    fn test_tools_omitted_when_none_added() {
        for client_type in [ClientLlm::Anthropic, ClientLlm::OpenAI] {
            let client = MockClient { client_type };
            let request = RequestBuilder::new(&client)
                .user_message("No tools here")
                .render_request()
                .unwrap();
            assert!(request.get("tools").is_none(), "Tools field should be absent");
        }
    }

    #[test]
    fn test_tool_use_anthropic() {
        // Rendering never touches the network, so a dummy key is enough. This keeps the
        // test runnable without ANTHROPIC_API_KEY in the environment.
        let request = LlmClient::new(ClientLlm::Anthropic, "test-key".to_string())
            .request()
            .add_tool(get_weather_tool())
            .model("claude-3-haiku-20240307")
            .user_message("What is the current weather in San Francisco, California")
            .max_tokens(100)
            .temperature(1.0)
            .system_prompt("You are a haiku assistant.")
            .render_request()
            .expect("Failed to render request");

        assert!(request.get("tools").is_some(), "Tools field is missing");
        let tools = request["tools"].as_array().expect("Tools should be an array");
        assert_eq!(tools.len(), 1, "There should be one tool");

        let tool = &tools[0];
        assert_eq!(tool["name"], "get_weather", "Tool name should be 'get_weather'");
        assert!(tool["input_schema"].is_object(), "Tool should have an input schema");

        let input_schema = &tool["input_schema"];
        assert_eq!(input_schema["type"], "object", "Input schema type should be 'object'");

        let properties = input_schema["properties"].as_object().expect("Properties should be an object");
        assert!(properties.contains_key("location"), "Location parameter should be present");
        assert!(properties.contains_key("unit"), "Unit parameter should be present");
    }

    #[test]
    fn test_function_calling_openai() {
        // Rendering never touches the network, so a dummy key is enough. This keeps the
        // test runnable without OPENAI_API_KEY in the environment.
        let request = LlmClient::new(ClientLlm::OpenAI, "test-key".to_string())
            .request()
            .add_tool(get_weather_tool())
            .model("gpt-4o")
            .user_message("What is the current weather in San Francisco, California")
            .max_tokens(100)
            .temperature(1.0)
            .system_prompt("You are a weather assistant.")
            .render_request()
            .expect("Failed to render request");

        assert!(request.get("tools").is_some(), "Tools field is missing");
        let tools = request["tools"].as_array().expect("Tools should be an array");
        assert_eq!(tools.len(), 1, "There should be one tool");

        let function = &tools[0];
        assert_eq!(function["type"], "function", "Tool type should be 'function'");

        let function_details = &function["function"];
        assert_eq!(function_details["name"], "get_weather", "Function name should be 'get_weather'");
        assert_eq!(function_details["description"], "Get the current weather in a given location", "Function description should match");

        let parameters = &function_details["parameters"];
        assert_eq!(parameters["type"], "object", "Parameters type should be 'object'");

        let properties = parameters["properties"].as_object().expect("Properties should be an object");
        assert!(properties.contains_key("location"), "Location parameter should be present");
        assert!(properties.contains_key("unit"), "Unit parameter should be present");

        let location = &properties["location"];
        assert_eq!(location["type"], "string", "Location type should be 'string'");

        let unit = &properties["unit"];
        assert_eq!(unit["type"], "string", "Unit type should be 'string'");
        assert!(unit.get("enum").is_some(), "Unit should have enum values");

        let required = parameters["required"].as_array().expect("Required should be an array");
        assert!(required.contains(&json!("location")), "Location should be a required parameter");

        assert_eq!(request["model"], "gpt-4o", "Model should be set correctly");
        assert_eq!(request["max_tokens"], 100, "Max tokens should be set correctly");
        assert_eq!(request["temperature"], 1.0, "Temperature should be set correctly");

        let messages = request["messages"].as_array().expect("Messages should be an array");
        assert!(messages.iter().any(|msg| msg["role"] == "system" && msg["content"] == "You are a weather assistant."),
            "System message should be included in the messages array");
    }
}