// cortexai_llm_client/request.rs

//! Request building for LLM APIs.

use crate::{LlmClientError, Message, Provider, Result};
use serde::Serialize;

/// An HTTP request ready to be sent.
///
/// This struct contains all the information needed to make an HTTP request,
/// but does NOT include any HTTP client implementation — callers hand it to
/// whatever transport they use.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// HTTP method (always POST for LLM APIs)
    pub method: &'static str,
    /// Full URL (taken from the provider's endpoint)
    pub url: String,
    /// Headers as key-value pairs (Content-Type, auth, provider extras)
    pub headers: Vec<(String, String)>,
    /// JSON body as string, already serialized for the target provider
    pub body: String,
}

/// Builder for constructing LLM API requests.
#[derive(Debug, Clone)]
pub struct RequestBuilder {
    // Target provider; selects endpoint, auth header format, and body format.
    provider: Provider,
    // Model identifier; required — `build` fails when unset.
    model: Option<String>,
    // Conversation messages in order; must be non-empty at `build` time.
    messages: Vec<Message>,
    // API key; required — `build` fails when unset.
    api_key: Option<String>,
    // Sampling temperature; `temperature()` clamps it to [0.0, 2.0].
    temperature: Option<f32>,
    // Generation cap; Anthropic bodies default to 4096 when unset.
    max_tokens: Option<u32>,
    // Whether to request a streaming response; `new` defaults this to false.
    stream: bool,
    // Nucleus sampling parameter; `top_p()` clamps it to [0.0, 1.0].
    top_p: Option<f32>,
    // Stop sequences; serialized as `stop` (OpenAI) / `stop_sequences` (Anthropic).
    stop: Option<Vec<String>>,
}

36impl RequestBuilder {
37    /// Create a new request builder for the given provider.
38    pub fn new(provider: Provider) -> Self {
39        Self {
40            provider,
41            model: None,
42            messages: Vec::new(),
43            api_key: None,
44            temperature: None,
45            max_tokens: None,
46            stream: false,
47            top_p: None,
48            stop: None,
49        }
50    }
51
52    /// Set the model to use.
53    pub fn model(mut self, model: impl Into<String>) -> Self {
54        self.model = Some(model.into());
55        self
56    }
57
58    /// Set the messages for the conversation.
59    pub fn messages(mut self, messages: &[Message]) -> Self {
60        self.messages = messages.to_vec();
61        self
62    }
63
64    /// Add a single message to the conversation.
65    pub fn add_message(mut self, message: Message) -> Self {
66        self.messages.push(message);
67        self
68    }
69
70    /// Set the API key.
71    pub fn api_key(mut self, key: impl Into<String>) -> Self {
72        self.api_key = Some(key.into());
73        self
74    }
75
76    /// Set the temperature (0.0 - 2.0).
77    pub fn temperature(mut self, temp: f32) -> Self {
78        self.temperature = Some(temp.clamp(0.0, 2.0));
79        self
80    }
81
82    /// Set the maximum tokens to generate.
83    pub fn max_tokens(mut self, tokens: u32) -> Self {
84        self.max_tokens = Some(tokens);
85        self
86    }
87
88    /// Enable or disable streaming.
89    pub fn stream(mut self, enable: bool) -> Self {
90        self.stream = enable;
91        self
92    }
93
94    /// Set top_p (nucleus sampling).
95    pub fn top_p(mut self, p: f32) -> Self {
96        self.top_p = Some(p.clamp(0.0, 1.0));
97        self
98    }
99
100    /// Set stop sequences.
101    pub fn stop(mut self, sequences: Vec<String>) -> Self {
102        self.stop = Some(sequences);
103        self
104    }
105
106    /// Build the HTTP request.
107    pub fn build(&self) -> Result<HttpRequest> {
108        let model = self
109            .model
110            .as_ref()
111            .ok_or_else(|| LlmClientError::missing("model"))?;
112        let api_key = self
113            .api_key
114            .as_ref()
115            .ok_or_else(|| LlmClientError::missing("api_key"))?;
116
117        if self.messages.is_empty() {
118            return Err(LlmClientError::missing("messages"));
119        }
120
121        let url = self.provider.endpoint().to_string();
122        let headers = self.build_headers(api_key);
123        let body = self.build_body(model)?;
124
125        Ok(HttpRequest {
126            method: "POST",
127            url,
128            headers,
129            body,
130        })
131    }
132
133    fn build_headers(&self, api_key: &str) -> Vec<(String, String)> {
134        let mut headers = vec![
135            ("Content-Type".to_string(), "application/json".to_string()),
136            (
137                self.provider.auth_header().to_string(),
138                self.provider.format_auth(api_key),
139            ),
140        ];
141
142        for (key, value) in self.provider.extra_headers() {
143            headers.push((key.to_string(), value.to_string()));
144        }
145
146        headers
147    }
148
149    fn build_body(&self, model: &str) -> Result<String> {
150        match self.provider {
151            Provider::OpenAI | Provider::OpenRouter => self.build_openai_body(model),
152            Provider::Anthropic => self.build_anthropic_body(model),
153        }
154    }
155
156    fn build_openai_body(&self, model: &str) -> Result<String> {
157        #[derive(Serialize)]
158        struct OpenAIRequest<'a> {
159            model: &'a str,
160            messages: &'a [OpenAIMessage<'a>],
161            #[serde(skip_serializing_if = "Option::is_none")]
162            temperature: Option<f32>,
163            #[serde(skip_serializing_if = "Option::is_none")]
164            max_tokens: Option<u32>,
165            stream: bool,
166            #[serde(skip_serializing_if = "Option::is_none")]
167            top_p: Option<f32>,
168            #[serde(skip_serializing_if = "Option::is_none")]
169            stop: Option<&'a [String]>,
170        }
171
172        #[derive(Serialize)]
173        struct OpenAIMessage<'a> {
174            role: &'a str,
175            content: &'a str,
176        }
177
178        let messages: Vec<OpenAIMessage> = self
179            .messages
180            .iter()
181            .map(|m| OpenAIMessage {
182                role: m.role.as_str(),
183                content: &m.content,
184            })
185            .collect();
186
187        let request = OpenAIRequest {
188            model,
189            messages: &messages,
190            temperature: self.temperature,
191            max_tokens: self.max_tokens,
192            stream: self.stream,
193            top_p: self.top_p,
194            stop: self.stop.as_deref(),
195        };
196
197        Ok(serde_json::to_string(&request)?)
198    }
199
200    fn build_anthropic_body(&self, model: &str) -> Result<String> {
201        #[derive(Serialize)]
202        struct AnthropicRequest<'a> {
203            model: &'a str,
204            #[serde(skip_serializing_if = "Option::is_none")]
205            system: Option<&'a str>,
206            messages: Vec<AnthropicMessage<'a>>,
207            max_tokens: u32,
208            #[serde(skip_serializing_if = "Option::is_none")]
209            temperature: Option<f32>,
210            stream: bool,
211            #[serde(skip_serializing_if = "Option::is_none")]
212            top_p: Option<f32>,
213            #[serde(skip_serializing_if = "Option::is_none")]
214            stop_sequences: Option<&'a [String]>,
215        }
216
217        #[derive(Serialize)]
218        struct AnthropicMessage<'a> {
219            role: &'a str,
220            content: &'a str,
221        }
222
223        // Extract system message (Anthropic handles it separately)
224        let system = self
225            .messages
226            .iter()
227            .find(|m| m.role == crate::Role::System)
228            .map(|m| m.content.as_str());
229
230        // Filter out system messages for the messages array
231        let messages: Vec<AnthropicMessage> = self
232            .messages
233            .iter()
234            .filter(|m| m.role != crate::Role::System)
235            .map(|m| AnthropicMessage {
236                role: if m.role == crate::Role::User {
237                    "user"
238                } else {
239                    "assistant"
240                },
241                content: &m.content,
242            })
243            .collect();
244
245        let request = AnthropicRequest {
246            model,
247            system,
248            messages,
249            max_tokens: self.max_tokens.unwrap_or(4096),
250            temperature: self.temperature,
251            stream: self.stream,
252            top_p: self.top_p,
253            stop_sequences: self.stop.as_deref(),
254        };
255
256        Ok(serde_json::to_string(&request)?)
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_request() {
        let req = RequestBuilder::new(Provider::OpenAI)
            .model("gpt-4o-mini")
            .api_key("sk-test")
            .add_message(Message::system("You are helpful"))
            .add_message(Message::user("Hello"))
            .temperature(0.7)
            .max_tokens(1024)
            .build()
            .expect("all required fields were set");

        assert_eq!(req.method, "POST");
        assert!(req.url.contains("openai.com"));
        assert!(req.body.contains("gpt-4o-mini"));
        assert!(req.body.contains("Hello"));

        // OpenAI auth is a bearer token in the Authorization header.
        let auth = req
            .headers
            .iter()
            .find(|(name, _)| name == "Authorization")
            .expect("Authorization header present");
        assert!(auth.1.starts_with("Bearer "));
    }

    #[test]
    fn test_anthropic_request() {
        let req = RequestBuilder::new(Provider::Anthropic)
            .model("claude-3-sonnet-20240229")
            .api_key("sk-ant-test")
            .add_message(Message::system("You are helpful"))
            .add_message(Message::user("Hello"))
            .max_tokens(1024)
            .build()
            .expect("all required fields were set");

        assert!(req.url.contains("anthropic.com"));
        assert!(req.body.contains("claude-3"));
        // The system message lands in a top-level field, not in `messages`.
        assert!(req.body.contains(r#""system":"You are helpful"#));

        // Anthropic requests must carry the anthropic-version header.
        assert!(req
            .headers
            .iter()
            .any(|(name, _)| name == "anthropic-version"));
    }

    #[test]
    fn test_missing_model() {
        let err = RequestBuilder::new(Provider::OpenAI)
            .api_key("sk-test")
            .add_message(Message::user("Hello"))
            .build()
            .unwrap_err();

        assert!(err.to_string().contains("model"));
    }

    #[test]
    fn test_missing_messages() {
        let err = RequestBuilder::new(Provider::OpenAI)
            .model("gpt-4")
            .api_key("sk-test")
            .build()
            .unwrap_err();

        assert!(err.to_string().contains("messages"));
    }

    #[test]
    fn test_streaming() {
        let req = RequestBuilder::new(Provider::OpenAI)
            .model("gpt-4")
            .api_key("sk-test")
            .add_message(Message::user("Hello"))
            .stream(true)
            .build()
            .expect("all required fields were set");

        assert!(req.body.contains(r#""stream":true"#));
    }
}