Skip to main content

llm_bridge/
client.rs

1//! This module provides a client for interacting with different LLM APIs.
2//!
3//! The `LlmClient` struct is the main entry point for making requests to LLM APIs.
4//! It uses a `RequestBuilder` to construct the request parameters and sends the request
5//! using the appropriate client implementation based on the selected `ClientLlm` enum variant.
6//!
7//! The `LlmClientTrait` defines the common interface for sending messages to LLM APIs,
8//! and the `AnthropicClient` and `OpenAIClient` structs implement this trait for their respective APIs.
9
10use log::{debug, error};
11use crate::error::ApiError;
12use crate::request::{Message, RequestBody};
13use reqwest::Client;
14use serde_json::{json, Number};
15use crate::response::{OpenAIResponse, ResponseMessage};
16use crate::tool::Tool;
17
/// URL that `AnthropicClient` posts every request to.
const API_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
/// Value sent in the `anthropic-version` request header.
const API_VERSION: &str = "2023-06-01";
/// Model used for Anthropic requests when the builder was given none.
const DEFAULT_ANTHROPIC_MODEL: &str = "claude-3-haiku-20240307";

/// Model used for OpenAI requests when the builder was given none.
const DEFAULT_OPENAI_MODEL: &str = "gpt-4o";
/// Token limit applied when the builder was given none.
const DEFAULT_MAX_TOKENS: u32 = 100;
/// Sampling temperature applied when the builder was given none.
const DEFAULT_TEMP: f64 = 0.0;
25
/// Supported LLM backends.
///
/// The variant decides both which HTTP client `LlmClient::new` constructs and
/// which JSON layout `RequestBuilder::render_request` produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientLlm {
    /// Anthropic Messages API.
    Anthropic,
    /// OpenAI Chat Completions API.
    OpenAI,
}
32
33#[async_trait::async_trait]
34pub trait LlmClientTrait: Send + Sync {
35    async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError>;
36    fn client_type(&self) -> ClientLlm;
37}
38
39/// Represents a builder for constructing a request to the Anthropic API.
40///
41/// The `RequestBuilder` allows setting various parameters for the request, such as the model,
42/// messages, max tokens, temperature, and system prompt. The `send` method sends the request
43/// to the API and returns the response.
44pub struct RequestBuilder<'a> {
45    client: &'a (dyn LlmClientTrait + Send + Sync),
46    model: Option<String>,
47    messages: Option<Vec<Message>>,
48    max_tokens: Option<u32>,
49    temperature: Option<f64>,
50    system_prompt: Option<String>,
51    tools: Option<Vec<Tool>>
52}
53
54impl<'a> RequestBuilder<'a> {
55    pub fn new(client: &'a (dyn LlmClientTrait + Send + Sync)) -> Self {
56        RequestBuilder {
57            client,
58            model: None,
59            messages: None,
60            max_tokens: None,
61            temperature: None,
62            system_prompt: None,
63            tools: None,
64        }
65    }
66
67    pub fn add_tool(mut self, tool: Tool) -> Self {
68        if let Some(mut tools) = self.tools {
69            tools.push(tool);
70            self.tools = Some(tools);
71        } else {
72            self.tools = Some(vec![tool]);
73        }
74        self
75    }
76
77    /// Sets the model to use for generating the response.
78    pub fn model(mut self, model: &str) -> Self {
79        self.model = Some(model.to_string());
80        self
81    }
82
83    /// Adds a user message to the conversation.
84    pub fn user_message(mut self, message: &str) -> Self {
85        if let Some(mut messages) = self.messages {
86            messages.push(Message {
87                role: "user".to_string(),
88                content: message.to_string(),
89            });
90            self.messages = Some(messages);
91        } else {
92            self.messages = Some(vec![Message {
93                role: "user".to_string(),
94                content: message.to_string(),
95            }]);
96        }
97        self
98    }
99
100    /// Sets the maximum number of tokens to generate in the response.
101    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
102        self.max_tokens = Some(max_tokens);
103        self
104    }
105
106    /// Sets the temperature value to control the randomness of the generated response.
107    pub fn temperature(mut self, temperature: f64) -> Self {
108        self.temperature = Some(temperature);
109        self
110    }
111
112    /// Sets the system prompt to provide context and instructions to the model.
113    pub fn system_prompt(mut self, system_prompt: &str) -> Self {
114        self.system_prompt = Some(system_prompt.to_string());
115        self
116    }
117
118    pub fn render_request(&self) -> Result<serde_json::Value, ApiError> {
119        let model = self.model.clone().unwrap_or_else(|| {
120            match self.client.client_type() {
121                ClientLlm::Anthropic => DEFAULT_ANTHROPIC_MODEL.to_string(),
122                ClientLlm::OpenAI => DEFAULT_OPENAI_MODEL.to_string(),
123                // Add more cases for other LLM APIs as needed
124            }
125        });
126        let messages = self.messages.clone().ok_or(ApiError::MissingMessages)?;
127        let max_tokens = self.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
128        let temperature = self.temperature.unwrap_or(DEFAULT_TEMP);
129        let temperature_number = Number::from_f64(temperature)
130            .ok_or_else(|| ApiError::InvalidUsage(format!("Invalid temperature value: {}", temperature)))?;
131        let system_prompt = self.system_prompt.clone().unwrap_or_default();
132
133        match self.client.client_type() {
134            ClientLlm::Anthropic => {
135                let mut request = json!({
136                    "model": model,
137                    "messages": messages,
138                    "max_tokens": max_tokens,
139                    "temperature": temperature_number,
140                    "system": system_prompt,
141                });
142
143                if let Some(tools) = &self.tools {
144                    let anthropic_tools: Vec<serde_json::Value> = tools.iter()
145                        .map(|tool| tool.to_anthropic_format())
146                        .collect();
147                    request["tools"] = json!(anthropic_tools);
148                }
149
150                Ok(request)
151            },
152            ClientLlm::OpenAI => {
153                let mut request = json!({
154                    "model": model,
155                    "messages": messages,
156                    "max_tokens": max_tokens,
157                    "temperature": temperature_number,
158                });
159
160                if !system_prompt.is_empty() {
161                    request["messages"].as_array_mut().unwrap().push(json!({
162                        "role": "system",
163                        "content": system_prompt
164                    }));
165                }
166
167                if let Some(tools) = &self.tools {
168                    let openai_tools: Vec<serde_json::Value> = tools.iter()
169                        .map(|tool| tool.to_openai_format())
170                        .collect();
171                    request["tools"] = json!(openai_tools);
172                }
173
174                Ok(request)
175            },
176        }
177    }
178
179
180    pub async fn send(self) -> Result<ResponseMessage, ApiError> {
181        let request_body = self.render_request()?;
182        self.client.send_message(request_body).await
183    }
184}
185
186/// Wrapper around the Anthropic LLM API client.
187pub struct AnthropicClient {
188    api_key: String,
189    client: Client,
190}
191
192impl AnthropicClient {
193    pub fn new(api_key: String) -> Self {
194        let client = Client::new();
195        AnthropicClient { api_key, client }
196    }
197}
198
199#[async_trait::async_trait]
200impl LlmClientTrait for AnthropicClient {
201    async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
202        let response = self.client
203            .post(API_ENDPOINT)
204            .header("x-api-key", &self.api_key)
205            .header("anthropic-version", API_VERSION)
206            .header("content-type", "application/json")
207            .json(&request_body)
208            .send()
209            .await?;
210        let resp_status = response.status();
211        let resp_text = response.text().await.unwrap_or("".into());
212        if resp_status.is_client_error() {
213            error!("Client error [{}]: {}", resp_status, resp_text);
214            return Err(ApiError::ClientError(
215                format!("Status: {} - Error: {}", resp_status, resp_text)));
216        } else if resp_status.is_server_error() {
217            error!("Server error [{}]: {}", resp_status, resp_text);
218            return Err(ApiError::ServerError(
219                format!("Status: {} - Error: {}", resp_status, resp_text)));
220        }
221        debug!("LLM call response: status[{}]\n{}", resp_status, resp_text);
222        let response_message = serde_json::from_str(&resp_text)?;
223
224        Ok(response_message)
225    }
226
227    fn client_type(&self) -> ClientLlm {
228        ClientLlm::Anthropic
229    }
230}
231
232/// Wrapper around the OpenAI LLM API client.
233pub struct OpenAIClient {
234    api_key: String,
235    client: Client,
236}
237
238impl OpenAIClient {
239    pub fn new(api_key: String) -> Self {
240        let client = Client::new();
241        OpenAIClient { api_key, client }
242    }
243}
244
245#[async_trait::async_trait]
246impl LlmClientTrait for OpenAIClient {
247    async fn send_message(&self, request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
248        let response = self.client
249            .post("https://api.openai.com/v1/chat/completions")
250            .header("Authorization", format!("Bearer {}", self.api_key))
251            .header("Content-Type", "application/json")
252            .json(&request_body)
253            .send()
254            .await?;
255
256        let resp_status = response.status();
257        let resp_text = response.text().await.unwrap_or("".into());
258        if resp_status.is_client_error() {
259            return Err(ApiError::ClientError(format!("Status: {} - Error: {}", resp_status, resp_text)));
260        } else if resp_status.is_server_error() {
261            return Err(ApiError::ServerError(format!("Status: {} - Error: {}", resp_status, resp_text)));
262        }
263
264        let openai_response: OpenAIResponse = serde_json::from_str(&resp_text)?;
265        Ok(ResponseMessage::OpenAI(openai_response))
266    }
267
268    fn client_type(&self) -> ClientLlm {
269        ClientLlm::OpenAI
270    }
271}
272
273/// The main client for interacting with LLM APIs.
274///
275/// The `LlmClient` struct provides a convenient way to make requests to LLM APIs using the
276/// `RequestBuilder`. It internally uses the appropriate client implementation based on the
277/// selected `ClientLlm` enum variant.
278pub struct LlmClient {
279    client: Box<dyn LlmClientTrait + Send + Sync>,
280}
281
282impl LlmClient {
283    /// Creates a new `LlmClient` instance with the specified `ClientLlm` variant and API key.
284    pub fn new(client_type: ClientLlm, api_key: String) -> Self {
285        let client: Box<dyn LlmClientTrait + Send + Sync> = match client_type {
286            ClientLlm::Anthropic => Box::new(AnthropicClient::new(api_key)),
287            ClientLlm::OpenAI => Box::new(OpenAIClient::new(api_key)),
288        };
289        LlmClient { client }
290    }
291
292    /// Creates a new `RequestBuilder` for constructing a request to the LLM API.
293    pub fn request(&mut self) -> RequestBuilder {
294        RequestBuilder::new(self.client.as_ref())
295    }
296}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::Tool;

    /// Placeholder credential: these tests only render requests and never
    /// contact a backend, so no real key is needed.
    const TEST_API_KEY: &str = "test-api-key";

    /// Minimal client that only reports a backend type; sending is unsupported.
    struct MockClient {
        client_type: ClientLlm,
    }

    #[async_trait::async_trait]
    impl LlmClientTrait for MockClient {
        async fn send_message(&self, _request_body: serde_json::Value) -> Result<ResponseMessage, ApiError> {
            unimplemented!()
        }

        fn client_type(&self) -> ClientLlm {
            self.client_type.clone()
        }
    }

    #[test]
    fn test_anthropic_default_request() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello, Claude!");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], DEFAULT_ANTHROPIC_MODEL);
        assert_eq!(request["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(request["temperature"], DEFAULT_TEMP);
        assert_eq!(request["system"], "");
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello, Claude!");
    }

    #[test]
    fn test_openai_default_request() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello, GPT!");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], DEFAULT_OPENAI_MODEL);
        assert_eq!(request["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(request["temperature"], DEFAULT_TEMP);
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello, GPT!");
    }

    #[test]
    fn test_custom_model_and_parameters() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .model("custom-model")
            .max_tokens(500)
            .temperature(0.8)
            .system_prompt("You are a helpful assistant.")
            .user_message("Tell me a joke.");

        let request = builder.render_request().unwrap();

        assert_eq!(request["model"], "custom-model");
        assert_eq!(request["max_tokens"], 500);

        // Check for exact temperature value
        assert_eq!(request["temperature"], json!(0.8));

        assert_eq!(request["system"], "You are a helpful assistant.");
        assert_eq!(request["messages"][0]["content"], "Tell me a joke.");
    }

    #[test]
    fn test_multiple_messages() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .user_message("Hello!")
            .user_message("How are you?");

        let request = builder.render_request().unwrap();

        assert_eq!(request["messages"].as_array().unwrap().len(), 2);
        assert_eq!(request["messages"][0]["content"], "Hello!");
        assert_eq!(request["messages"][1]["content"], "How are you?");
    }

    #[test]
    fn test_missing_messages() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client);

        let result = builder.render_request();

        assert!(matches!(result, Err(ApiError::MissingMessages)));
    }

    #[test]
    fn test_openai_system_prompt() {
        let client = MockClient { client_type: ClientLlm::OpenAI };
        let builder = RequestBuilder::new(&client)
            .system_prompt("You are a helpful assistant.")
            .user_message("Hello!");

        let request = builder.render_request().unwrap();

        assert_eq!(request["messages"].as_array().unwrap().len(), 2);
        assert_eq!(request["messages"][1]["role"], "system");
        assert_eq!(request["messages"][1]["content"], "You are a helpful assistant.");
        assert_eq!(request["messages"][0]["role"], "user");
        assert_eq!(request["messages"][0]["content"], "Hello!");
    }

    #[test]
    fn test_default_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let builder = RequestBuilder::new(&client)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(DEFAULT_TEMP));
    }

    #[test]
    fn test_custom_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let custom_temp = 0.7;
        let builder = RequestBuilder::new(&client)
            .temperature(custom_temp)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(custom_temp));
    }

    #[test]
    fn test_temperature_precision() {
        let client = MockClient { client_type: ClientLlm::Anthropic };
        let precise_temp = 0.12345;
        let builder = RequestBuilder::new(&client)
            .temperature(precise_temp)
            .user_message("Test message");

        let request = builder.render_request().unwrap();

        assert_eq!(request["temperature"], json!(precise_temp));
    }

    #[test]
    fn test_invalid_temperature() {
        let client = MockClient { client_type: ClientLlm::Anthropic };

        // Associated constants replace the deprecated `std::f64::{INFINITY, NEG_INFINITY}`.
        for &invalid_temp in &[f64::INFINITY, f64::NEG_INFINITY, f64::NAN] {
            let builder = RequestBuilder::new(&client)
                .temperature(invalid_temp)
                .user_message("Test message");

            let result = builder.render_request();
            assert!(matches!(result, Err(ApiError::InvalidUsage(_))));
        }
    }

    /// Builds the weather tool shared by the tool-rendering tests.
    fn get_weather_tool() -> Tool {
        Tool::builder()
            .name("get_weather")
            .description("Get the current weather in a given location")
            .add_parameter("location", "string", "The city and state, e.g. San Francisco, CA", true)
            .add_enum_parameter("unit", "The unit of temperature, either 'celsius' or 'fahrenheit'", false, vec!["celsius".to_string(), "fahrenheit".to_string()])
            .build()
            .expect("Failed to build tool")
    }

    #[test]
    fn test_tool_use_anthropic() {
        let tool = get_weather_tool();

        // Built in one expression so the temporary client outlives the builder
        // that borrows it. Only `render_request` runs; nothing is sent.
        let request = LlmClient::new(ClientLlm::Anthropic, TEST_API_KEY.to_string())
            .request()
            .add_tool(tool)
            .model("claude-3-haiku-20240307")
            .user_message("What is the current weather in San Francisco, California")
            .max_tokens(100)
            .temperature(1.0)
            .system_prompt("You are a haiku assistant.")
            .render_request()
            .expect("Failed to render request");

        // Check if the tools field is present and correctly formatted
        assert!(request.get("tools").is_some(), "Tools field is missing");
        let tools = request["tools"].as_array().expect("Tools should be an array");
        assert_eq!(tools.len(), 1, "There should be one tool");

        let tool = &tools[0];
        assert_eq!(tool["name"], "get_weather", "Tool name should be 'get_weather'");
        assert!(tool["input_schema"].is_object(), "Tool should have an input schema");

        let input_schema = &tool["input_schema"];
        assert_eq!(input_schema["type"], "object", "Input schema type should be 'object'");

        let properties = input_schema["properties"].as_object().expect("Properties should be an object");
        assert!(properties.contains_key("location"), "Location parameter should be present");
        assert!(properties.contains_key("unit"), "Unit parameter should be present");
    }

    #[test]
    fn test_function_calling_openai() {
        let tool = get_weather_tool();

        // Built in one expression so the temporary client outlives the builder
        // that borrows it. Only `render_request` runs; nothing is sent.
        let request = LlmClient::new(ClientLlm::OpenAI, TEST_API_KEY.to_string())
            .request()
            .add_tool(tool)
            .model("gpt-4o")
            .user_message("What is the current weather in San Francisco, California")
            .max_tokens(100)
            .temperature(1.0)
            .system_prompt("You are a weather assistant.")
            .render_request()
            .expect("Failed to render request");

        // Check if the tools field is present and correctly formatted
        assert!(request.get("tools").is_some(), "Tools field is missing");
        let tools = request["tools"].as_array().expect("Tools should be an array");
        assert_eq!(tools.len(), 1, "There should be one tool");

        let function = &tools[0];
        assert_eq!(function["type"], "function", "Tool type should be 'function'");

        let function_details = &function["function"];
        assert_eq!(function_details["name"], "get_weather", "Function name should be 'get_weather'");
        assert_eq!(function_details["description"], "Get the current weather in a given location", "Function description should match");

        let parameters = &function_details["parameters"];
        assert_eq!(parameters["type"], "object", "Parameters type should be 'object'");

        let properties = parameters["properties"].as_object().expect("Properties should be an object");
        assert!(properties.contains_key("location"), "Location parameter should be present");
        assert!(properties.contains_key("unit"), "Unit parameter should be present");

        let location = &properties["location"];
        assert_eq!(location["type"], "string", "Location type should be 'string'");

        let unit = &properties["unit"];
        assert_eq!(unit["type"], "string", "Unit type should be 'string'");
        assert!(unit.get("enum").is_some(), "Unit should have enum values");

        let required = parameters["required"].as_array().expect("Required should be an array");
        assert!(required.contains(&json!("location")), "Location should be a required parameter");

        // Check other request parameters
        assert_eq!(request["model"], "gpt-4o", "Model should be set correctly");
        assert_eq!(request["max_tokens"], 100, "Max tokens should be set correctly");
        assert_eq!(request["temperature"], 1.0, "Temperature should be set correctly");

        // Check that the system message is included in the messages array
        let messages = request["messages"].as_array().expect("Messages should be an array");
        assert!(messages.iter().any(|msg| msg["role"] == "system" && msg["content"] == "You are a weather assistant."),
                "System message should be included in the messages array");
    }
}