1use crate::{
7 exceptions::{LangExtractError, LangExtractResult},
8 logging::{report_progress, ProgressEvent},
9};
10use serde_json::Value;
11use std::collections::HashMap;
12use tokio::time::Duration;
13
/// Configuration for [`HttpClient`]: request timeout, retry policy, and
/// default headers applied to every request.
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Per-request timeout, in seconds, installed on the underlying reqwest client.
    pub timeout_seconds: u64,
    /// Maximum number of retries after the initial attempt
    /// (total attempts = `max_retries + 1`).
    pub max_retries: usize,
    /// Base delay between retry attempts, in seconds.
    pub base_delay_seconds: u64,
    /// When true, the delay between retries grows with each attempt;
    /// when false, every retry waits `base_delay_seconds`.
    pub exponential_backoff: bool,
    /// Extra headers added to every outgoing request.
    pub headers: HashMap<String, String>,
}
28
29impl Default for HttpConfig {
30 fn default() -> Self {
31 Self {
32 timeout_seconds: 120,
33 max_retries: 3,
34 base_delay_seconds: 30,
35 exponential_backoff: true,
36 headers: HashMap::new(),
37 }
38 }
39}
40
/// Thin wrapper around `reqwest::Client` that applies the retry, timeout,
/// and default-header policy described by an [`HttpConfig`].
pub struct HttpClient {
    // Underlying reqwest client; built once with the configured timeout.
    client: reqwest::Client,
    // Retained so per-request headers and retry settings can be consulted later.
    config: HttpConfig,
}
46
47impl HttpClient {
48 pub fn new() -> Self {
50 Self::with_config(HttpConfig::default())
51 }
52
53 pub fn with_config(config: HttpConfig) -> Self {
55 let client = reqwest::Client::builder()
56 .timeout(Duration::from_secs(config.timeout_seconds))
57 .build()
58 .unwrap_or_else(|_| reqwest::Client::new());
59
60 Self { client, config }
61 }
62
63 pub async fn post_json_with_retry<T>(
65 &self,
66 url: &str,
67 body: &T,
68 operation_name: &str,
69 ) -> LangExtractResult<Value>
70 where
71 T: serde::Serialize,
72 {
73 self.retry_with_backoff(
74 || async {
75 self.post_json_single(url, body).await
76 },
77 operation_name,
78 ).await
79 }
80
81 async fn post_json_single<T>(&self, url: &str, body: &T) -> LangExtractResult<Value>
83 where
84 T: serde::Serialize,
85 {
86 let mut request = self.client.post(url).json(body);
87
88 for (key, value) in &self.config.headers {
90 request = request.header(key, value);
91 }
92
93 let response = request.send().await.map_err(|e| {
94 report_progress(ProgressEvent::Error {
95 operation: "HTTP request".to_string(),
96 error: format!("Request failed: {}", e),
97 });
98 LangExtractError::NetworkError(e)
99 })?;
100
101 if !response.status().is_success() {
102 let status = response.status();
103 let status_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
104
105 report_progress(ProgressEvent::Error {
106 operation: "HTTP response".to_string(),
107 error: format!("HTTP {} - {}", status, status_text),
108 });
109
110 return Err(LangExtractError::inference_simple(format!(
111 "HTTP error {}: {}",
112 status,
113 status_text
114 )));
115 }
116
117 let response_body: Value = response.json().await.map_err(|e| {
118 report_progress(ProgressEvent::Error {
119 operation: "JSON parsing".to_string(),
120 error: format!("Failed to parse response: {}", e),
121 });
122 LangExtractError::parsing(format!("Failed to parse JSON response: {}", e))
123 })?;
124
125 Ok(response_body)
126 }
127
128 async fn retry_with_backoff<T, F, Fut>(
130 &self,
131 mut operation: F,
132 operation_name: &str,
133 ) -> LangExtractResult<T>
134 where
135 F: FnMut() -> Fut,
136 Fut: std::future::Future<Output = LangExtractResult<T>>,
137 {
138 let max_retries = self.config.max_retries;
139 let base_delay = Duration::from_secs(self.config.base_delay_seconds);
140
141 for attempt in 0..=max_retries {
142 match operation().await {
143 Ok(result) => return Ok(result),
144 Err(e) => {
145 if attempt == max_retries {
146 return Err(LangExtractError::inference_simple(
148 format!("{} failed after {} attempts. Last error: {}",
149 operation_name, max_retries + 1, e)
150 ));
151 }
152
153 let delay = if self.config.exponential_backoff {
155 base_delay * (attempt + 1) as u32
156 } else {
157 base_delay
158 };
159
160 report_progress(ProgressEvent::RetryAttempt {
161 operation: operation_name.to_string(),
162 attempt: attempt + 1,
163 max_attempts: max_retries + 1,
164 delay_seconds: delay.as_secs(),
165 });
166
167 tokio::time::sleep(delay).await;
169 }
170 }
171 }
172
173 unreachable!("Should have returned from the loop")
174 }
175
176 pub fn with_header(mut self, key: String, value: String) -> Self {
178 self.config.headers.insert(key, value);
179 self
180 }
181
182 pub fn with_auth_header(self, auth_type: &str, token: &str) -> Self {
184 self.with_header("Authorization".to_string(), format!("{} {}", auth_type, token))
185 }
186
187 pub fn with_bearer_token(self, token: &str) -> Self {
189 self.with_auth_header("Bearer", token)
190 }
191
192 pub fn with_api_key(self, key: &str) -> Self {
194 self.with_header("X-API-Key".to_string(), key.to_string())
195 }
196}
197
198impl Default for HttpClient {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204impl HttpClient {
206 pub fn for_openai(api_key: &str) -> Self {
208 Self::new()
209 .with_bearer_token(api_key)
210 .with_header("Content-Type".to_string(), "application/json".to_string())
211 }
212
213 pub fn for_ollama() -> Self {
215 Self::with_config(HttpConfig {
216 timeout_seconds: 300, max_retries: 2, base_delay_seconds: 5, ..Default::default()
220 })
221 .with_header("Content-Type".to_string(), "application/json".to_string())
222 }
223
224 pub fn for_custom_provider(api_key: Option<&str>) -> Self {
226 let mut client = Self::with_config(HttpConfig {
227 timeout_seconds: 180,
228 max_retries: 3,
229 base_delay_seconds: 15,
230 ..Default::default()
231 })
232 .with_header("Content-Type".to_string(), "application/json".to_string());
233
234 if let Some(key) = api_key {
235 client = client.with_bearer_token(key);
236 }
237
238 client
239 }
240}
241
/// Stateless helpers that build JSON request bodies for supported providers.
pub struct RequestBuilder;
244
245impl RequestBuilder {
246 pub fn openai_chat_completion(
248 model: &str,
249 messages: Vec<serde_json::Value>,
250 temperature: Option<f32>,
251 max_tokens: Option<u32>,
252 ) -> serde_json::Value {
253 let mut request = serde_json::json!({
254 "model": model,
255 "messages": messages,
256 });
257
258 if let Some(temp) = temperature {
259 request["temperature"] = serde_json::json!(temp);
260 }
261
262 if let Some(tokens) = max_tokens {
263 request["max_tokens"] = serde_json::json!(tokens);
264 }
265
266 request
267 }
268
269 pub fn ollama_generate(
271 model: &str,
272 prompt: &str,
273 temperature: Option<f32>,
274 options: Option<&serde_json::Value>,
275 ) -> serde_json::Value {
276 let mut request = serde_json::json!({
277 "model": model,
278 "prompt": prompt,
279 "stream": false,
280 });
281
282 if let Some(temp) = temperature {
283 request["options"] = serde_json::json!({
284 "temperature": temp
285 });
286 }
287
288 if let Some(opts) = options {
289 if let Some(existing_opts) = request.get_mut("options") {
290 if let (Some(existing_map), Some(new_map)) = (existing_opts.as_object_mut(), opts.as_object()) {
292 for (key, value) in new_map {
293 existing_map.insert(key.clone(), value.clone());
294 }
295 }
296 } else {
297 request["options"] = opts.clone();
298 }
299 }
300
301 request
302 }
303
304 pub fn openai_system_message(content: &str) -> serde_json::Value {
306 serde_json::json!({
307 "role": "system",
308 "content": content
309 })
310 }
311
312 pub fn openai_user_message(content: &str) -> serde_json::Value {
314 serde_json::json!({
315 "role": "user",
316 "content": content
317 })
318 }
319}
320
/// Stateless helpers that extract the generated text from provider responses.
pub struct ResponseParser;
323
324impl ResponseParser {
325 pub fn openai_response_text(response: &Value) -> LangExtractResult<String> {
327 response
328 .get("choices")
329 .and_then(|choices| choices.as_array())
330 .and_then(|arr| arr.first())
331 .and_then(|choice| choice.get("message"))
332 .and_then(|message| message.get("content"))
333 .and_then(|content| content.as_str())
334 .map(|s| s.to_string())
335 .ok_or_else(|| LangExtractError::parsing("Invalid OpenAI response format"))
336 }
337
338 pub fn ollama_response_text(response: &Value) -> LangExtractResult<String> {
340 response
341 .get("response")
342 .and_then(|r| r.as_str())
343 .map(|s| s.to_string())
344 .ok_or_else(|| LangExtractError::parsing("Missing 'response' field in Ollama response"))
345 }
346
347 pub fn generic_response_text(response: &Value) -> LangExtractResult<String> {
349 let common_fields = ["response", "text", "content", "output", "result"];
351
352 for field in &common_fields {
353 if let Some(text) = response.get(field).and_then(|v| v.as_str()) {
354 return Ok(text.to_string());
355 }
356 }
357
358 if let Some(data) = response.get("data") {
360 return Self::generic_response_text(data);
361 }
362
363 Err(LangExtractError::parsing("Could not extract text from response"))
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the documented HttpConfig defaults so accidental changes surface.
    #[test]
    fn test_http_config_default() {
        let config = HttpConfig::default();
        assert_eq!(config.timeout_seconds, 120);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_seconds, 30);
        assert!(config.exponential_backoff);
    }

    // Optional temperature/max_tokens must appear in the built request
    // when provided.
    #[test]
    fn test_request_builder_openai() {
        let messages = vec![
            RequestBuilder::openai_system_message("You are helpful"),
            RequestBuilder::openai_user_message("Hello"),
        ];

        let request = RequestBuilder::openai_chat_completion(
            "gpt-4",
            messages,
            Some(0.7),
            Some(100),
        );

        assert_eq!(request["model"], "gpt-4");
        assert_eq!(request["temperature"], 0.7);
        assert_eq!(request["max_tokens"], 100);
        assert!(request["messages"].is_array());
    }

    // Temperature should be nested under "options"; streaming disabled.
    #[test]
    fn test_request_builder_ollama() {
        let request = RequestBuilder::ollama_generate(
            "mistral",
            "Hello world",
            Some(0.5),
            None,
        );

        assert_eq!(request["model"], "mistral");
        assert_eq!(request["prompt"], "Hello world");
        assert_eq!(request["stream"], false);
        assert_eq!(request["options"]["temperature"], 0.5);
    }

    // Happy path: choices[0].message.content is extracted.
    #[test]
    fn test_response_parser_openai() {
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello, world!"
                }
            }]
        });

        let text = ResponseParser::openai_response_text(&response).unwrap();
        assert_eq!(text, "Hello, world!");
    }

    // Happy path: top-level "response" string is extracted.
    #[test]
    fn test_response_parser_ollama() {
        let response = serde_json::json!({
            "response": "Hello from Ollama!"
        });

        let text = ResponseParser::ollama_response_text(&response).unwrap();
        assert_eq!(text, "Hello from Ollama!");
    }

    // Covers both a common top-level field and the nested "data" fallback.
    #[test]
    fn test_response_parser_generic() {
        let response1 = serde_json::json!({
            "text": "Generic response"
        });

        let response2 = serde_json::json!({
            "data": {
                "content": "Nested response"
            }
        });

        assert_eq!(ResponseParser::generic_response_text(&response1).unwrap(), "Generic response");
        assert_eq!(ResponseParser::generic_response_text(&response2).unwrap(), "Nested response");
    }
}