langextract_rust/
http_client.rs

//! HTTP client utilities for LangExtract providers.
//!
//! This module provides common HTTP functionality to reduce duplication
//! across different provider implementations.
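//!
//! # Example
//!
//! A minimal sketch of the intended flow. The `langextract_rust::http_client`
//! crate path and module visibility are assumptions, so the snippet is not
//! compiled:
//!
//! ```rust,ignore
//! use langextract_rust::http_client::{HttpClient, ResponseParser};
//!
//! # async fn demo() -> Result<(), Box<dyn std::error::Error>> {
//! // Build a client tuned for a local Ollama server and send one request.
//! let client = HttpClient::for_ollama();
//! let body = serde_json::json!({ "model": "mistral", "prompt": "Hello", "stream": false });
//! let response = client
//!     .post_json_with_retry("http://localhost:11434/api/generate", &body, "ollama generate")
//!     .await?;
//! let text = ResponseParser::ollama_response_text(&response)?;
//! println!("{}", text);
//! # Ok(())
//! # }
//! ```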

use crate::{
    exceptions::{LangExtractError, LangExtractResult},
    logging::{report_progress, ProgressEvent},
};
use serde_json::Value;
use std::collections::HashMap;
use tokio::time::Duration;

/// Configuration for HTTP requests
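///
/// # Example
///
/// Overriding a couple of fields while keeping the rest of the defaults
/// (struct-update syntax; the crate path is an assumption, so the snippet is
/// not compiled):
///
/// ```rust,ignore
/// use langextract_rust::http_client::HttpConfig;
///
/// let config = HttpConfig {
///     timeout_seconds: 60, // shorter than the 120s default
///     max_retries: 5,      // more than the default 3
///     ..Default::default()
/// };
/// assert!(config.exponential_backoff);
/// ```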
#[derive(Debug, Clone)]
pub struct HttpConfig {
    /// Request timeout in seconds
    pub timeout_seconds: u64,
    /// Maximum number of retries
    pub max_retries: usize,
    /// Base delay between retries in seconds
    pub base_delay_seconds: u64,
    /// Whether to use exponential backoff
    pub exponential_backoff: bool,
    /// Custom headers to include in requests
    pub headers: HashMap<String, String>,
}

impl Default for HttpConfig {
    fn default() -> Self {
        Self {
            timeout_seconds: 120,
            max_retries: 3,
            base_delay_seconds: 30,
            exponential_backoff: true,
            headers: HashMap::new(),
        }
    }
}

/// HTTP client with retry logic and progress reporting
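///
/// # Example
///
/// A sketch of the builder-style header helpers defined below (crate path is
/// an assumption, so the snippet is not compiled):
///
/// ```rust,ignore
/// use langextract_rust::http_client::HttpClient;
///
/// // Every request sent by this client carries both headers.
/// // "X-Request-Source" is an arbitrary example header, not part of the API.
/// let client = HttpClient::new()
///     .with_bearer_token("secret-token")
///     .with_header("X-Request-Source".to_string(), "docs-example".to_string());
/// ```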
pub struct HttpClient {
    client: reqwest::Client,
    config: HttpConfig,
}

impl HttpClient {
    /// Create a new HTTP client with default configuration
    pub fn new() -> Self {
        Self::with_config(HttpConfig::default())
    }

    /// Create a new HTTP client with custom configuration
    pub fn with_config(config: HttpConfig) -> Self {
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(config.timeout_seconds))
            .build()
            // Fall back to a stock client if the builder fails; note that
            // this silently drops the configured timeout.
            .unwrap_or_else(|_| reqwest::Client::new());

        Self { client, config }
    }

    /// POST request with JSON body and retry logic
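    ///
    /// # Example
    ///
    /// A sketch of a call site; the endpoint URL and error type path are
    /// assumptions, so the snippet is not compiled:
    ///
    /// ```rust,ignore
    /// # async fn demo() -> langextract_rust::exceptions::LangExtractResult<()> {
    /// let client = HttpClient::new().with_bearer_token("secret-token");
    /// let body = serde_json::json!({ "input": "some text" });
    /// // Retries up to `max_retries` times before giving up.
    /// let response = client
    ///     .post_json_with_retry("https://api.example.com/v1/extract", &body, "extract call")
    ///     .await?;
    /// # Ok(())
    /// # }
    /// ```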
    pub async fn post_json_with_retry<T>(
        &self,
        url: &str,
        body: &T,
        operation_name: &str,
    ) -> LangExtractResult<Value>
    where
        T: serde::Serialize,
    {
        self.retry_with_backoff(
            || async {
                self.post_json_single(url, body).await
            },
            operation_name,
        ).await
    }

    /// Send a single POST request with a JSON body (no retries)
    async fn post_json_single<T>(&self, url: &str, body: &T) -> LangExtractResult<Value>
    where
        T: serde::Serialize,
    {
        let mut request = self.client.post(url).json(body);

        // Add custom headers
        for (key, value) in &self.config.headers {
            request = request.header(key, value);
        }

        let response = request.send().await.map_err(|e| {
            report_progress(ProgressEvent::Error {
                operation: "HTTP request".to_string(),
                error: format!("Request failed: {}", e),
            });
            LangExtractError::NetworkError(e)
        })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());

            report_progress(ProgressEvent::Error {
                operation: "HTTP response".to_string(),
                error: format!("HTTP {} - {}", status, error_body),
            });

            return Err(LangExtractError::inference_simple(format!(
                "HTTP error {}: {}",
                status, error_body
            )));
        }

        let response_body: Value = response.json().await.map_err(|e| {
            report_progress(ProgressEvent::Error {
                operation: "JSON parsing".to_string(),
                error: format!("Failed to parse response: {}", e),
            });
            LangExtractError::parsing(format!("Failed to parse JSON response: {}", e))
        })?;

        Ok(response_body)
    }

    /// Retry helper that re-runs `operation` with backoff between attempts:
    /// the delay doubles each retry when `exponential_backoff` is enabled,
    /// and stays at `base_delay_seconds` otherwise
    async fn retry_with_backoff<T, F, Fut>(
        &self,
        mut operation: F,
        operation_name: &str,
    ) -> LangExtractResult<T>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = LangExtractResult<T>>,
    {
        let max_retries = self.config.max_retries;
        let base_delay = Duration::from_secs(self.config.base_delay_seconds);

        for attempt in 0..=max_retries {
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    if attempt == max_retries {
                        // Last attempt failed, return the error
                        return Err(LangExtractError::inference_simple(
                            format!("{} failed after {} attempts. Last error: {}",
                                operation_name, max_retries + 1, e)
                        ));
                    }

                    // Calculate delay for next attempt: base * 2^attempt when
                    // exponential backoff is enabled, otherwise a flat base delay.
                    let delay = if self.config.exponential_backoff {
                        base_delay * 2u32.saturating_pow(attempt as u32)
                    } else {
                        base_delay
                    };

                    report_progress(ProgressEvent::RetryAttempt {
                        operation: operation_name.to_string(),
                        attempt: attempt + 1,
                        max_attempts: max_retries + 1,
                        delay_seconds: delay.as_secs(),
                    });

                    // Sleep before retry
                    tokio::time::sleep(delay).await;
                }
            }
        }

        unreachable!("Should have returned from the loop")
    }

    /// Add a header to all requests
    pub fn with_header(mut self, key: String, value: String) -> Self {
        self.config.headers.insert(key, value);
        self
    }

    /// Set authentication header
    pub fn with_auth_header(self, auth_type: &str, token: &str) -> Self {
        self.with_header("Authorization".to_string(), format!("{} {}", auth_type, token))
    }

    /// Set bearer token authentication
    pub fn with_bearer_token(self, token: &str) -> Self {
        self.with_auth_header("Bearer", token)
    }

    /// Set API key header
    pub fn with_api_key(self, key: &str) -> Self {
        self.with_header("X-API-Key".to_string(), key.to_string())
    }
}

impl Default for HttpClient {
    fn default() -> Self {
        Self::new()
    }
}

/// Provider-specific HTTP client builders
impl HttpClient {
    /// Create HTTP client configured for OpenAI
    pub fn for_openai(api_key: &str) -> Self {
        Self::new()
            .with_bearer_token(api_key)
            // reqwest's .json() also sets this, but being explicit is harmless.
            .with_header("Content-Type".to_string(), "application/json".to_string())
    }

    /// Create HTTP client configured for Ollama
    pub fn for_ollama() -> Self {
        Self::with_config(HttpConfig {
            timeout_seconds: 300,  // Longer timeout for local inference
            max_retries: 2,        // Fewer retries for local
            base_delay_seconds: 5, // Shorter delays for local
            ..Default::default()
        })
        .with_header("Content-Type".to_string(), "application/json".to_string())
    }

    /// Create HTTP client for custom providers
    pub fn for_custom_provider(api_key: Option<&str>) -> Self {
        let mut client = Self::with_config(HttpConfig {
            timeout_seconds: 180,
            max_retries: 3,
            base_delay_seconds: 15,
            ..Default::default()
        })
        .with_header("Content-Type".to_string(), "application/json".to_string());

        if let Some(key) = api_key {
            client = client.with_bearer_token(key);
        }

        client
    }
}

/// Common request/response utilities
pub struct RequestBuilder;

impl RequestBuilder {
    /// Build OpenAI-compatible chat completion request
    pub fn openai_chat_completion(
        model: &str,
        messages: Vec<serde_json::Value>,
        temperature: Option<f32>,
        max_tokens: Option<u32>,
    ) -> serde_json::Value {
        let mut request = serde_json::json!({
            "model": model,
            "messages": messages,
        });

        if let Some(temp) = temperature {
            request["temperature"] = serde_json::json!(temp);
        }

        if let Some(tokens) = max_tokens {
            request["max_tokens"] = serde_json::json!(tokens);
        }

        request
    }

    /// Build Ollama generate request
    pub fn ollama_generate(
        model: &str,
        prompt: &str,
        temperature: Option<f32>,
        options: Option<&serde_json::Value>,
    ) -> serde_json::Value {
        let mut request = serde_json::json!({
            "model": model,
            "prompt": prompt,
            "stream": false,
        });

        if let Some(temp) = temperature {
            request["options"] = serde_json::json!({
                "temperature": temp
            });
        }

        if let Some(opts) = options {
            if let Some(existing_opts) = request.get_mut("options") {
                // Merge caller-supplied options into the existing map;
                // caller values win on key collisions (including "temperature").
                if let (Some(existing_map), Some(new_map)) = (existing_opts.as_object_mut(), opts.as_object()) {
                    for (key, value) in new_map {
                        existing_map.insert(key.clone(), value.clone());
                    }
                }
            } else {
                request["options"] = opts.clone();
            }
        }

        request
    }

    /// Create OpenAI system message
    pub fn openai_system_message(content: &str) -> serde_json::Value {
        serde_json::json!({
            "role": "system",
            "content": content
        })
    }

    /// Create OpenAI user message
    pub fn openai_user_message(content: &str) -> serde_json::Value {
        serde_json::json!({
            "role": "user",
            "content": content
        })
    }
}

/// Response parser utilities
pub struct ResponseParser;

impl ResponseParser {
    /// Extract text content from OpenAI response
    pub fn openai_response_text(response: &Value) -> LangExtractResult<String> {
        response
            .get("choices")
            .and_then(|choices| choices.as_array())
            .and_then(|arr| arr.first())
            .and_then(|choice| choice.get("message"))
            .and_then(|message| message.get("content"))
            .and_then(|content| content.as_str())
            .map(|s| s.to_string())
            .ok_or_else(|| LangExtractError::parsing("Invalid OpenAI response format"))
    }

    /// Extract text content from Ollama response
    pub fn ollama_response_text(response: &Value) -> LangExtractResult<String> {
        response
            .get("response")
            .and_then(|r| r.as_str())
            .map(|s| s.to_string())
            .ok_or_else(|| LangExtractError::parsing("Missing 'response' field in Ollama response"))
    }

    /// Generic response text extractor that tries common fields
    pub fn generic_response_text(response: &Value) -> LangExtractResult<String> {
        // Try common response field names
        let common_fields = ["response", "text", "content", "output", "result"];

        for field in &common_fields {
            if let Some(text) = response.get(field).and_then(|v| v.as_str()) {
                return Ok(text.to_string());
            }
        }

        // Try nested structures
        if let Some(data) = response.get("data") {
            return Self::generic_response_text(data);
        }

        Err(LangExtractError::parsing("Could not extract text from response"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_http_config_default() {
        let config = HttpConfig::default();
        assert_eq!(config.timeout_seconds, 120);
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_seconds, 30);
        assert!(config.exponential_backoff);
    }
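
    // A sketch documenting the delay schedule produced by retry_with_backoff
    // with the defaults; it mirrors the formula used there (base * 2^attempt).
    #[test]
    fn test_exponential_backoff_schedule() {
        let base = Duration::from_secs(HttpConfig::default().base_delay_seconds);
        let delays: Vec<u64> = (0..3u32)
            .map(|attempt| (base * 2u32.saturating_pow(attempt)).as_secs())
            .collect();
        // With the 30s default base delay: 30s, 60s, then 120s between attempts.
        assert_eq!(delays, vec![30, 60, 120]);
    }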

    #[test]
    fn test_request_builder_openai() {
        let messages = vec![
            RequestBuilder::openai_system_message("You are helpful"),
            RequestBuilder::openai_user_message("Hello"),
        ];

        let request = RequestBuilder::openai_chat_completion(
            "gpt-4",
            messages,
            Some(0.7),
            Some(100),
        );

        assert_eq!(request["model"], "gpt-4");
        // 0.7 is not exactly representable as f32, so compare the stored value
        // approximately rather than with assert_eq! against an f64 literal.
        let temp = request["temperature"].as_f64().unwrap();
        assert!((temp - 0.7).abs() < 1e-6);
        assert_eq!(request["max_tokens"], 100);
        assert!(request["messages"].is_array());
    }

    #[test]
    fn test_request_builder_ollama() {
        let request = RequestBuilder::ollama_generate(
            "mistral",
            "Hello world",
            Some(0.5),
            None,
        );

        assert_eq!(request["model"], "mistral");
        assert_eq!(request["prompt"], "Hello world");
        assert_eq!(request["stream"], false);
        assert_eq!(request["options"]["temperature"], 0.5);
    }
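
    // A sketch exercising the options-merge path of ollama_generate:
    // caller-supplied options are merged into the temperature map built there.
    #[test]
    fn test_request_builder_ollama_options_merge() {
        let extra = serde_json::json!({ "num_predict": 100 });
        let request = RequestBuilder::ollama_generate(
            "mistral",
            "Hello world",
            Some(0.25), // exactly representable, so the equality assert is safe
            Some(&extra),
        );

        assert_eq!(request["options"]["temperature"], 0.25);
        assert_eq!(request["options"]["num_predict"], 100);
    }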

    #[test]
    fn test_response_parser_openai() {
        let response = serde_json::json!({
            "choices": [{
                "message": {
                    "content": "Hello, world!"
                }
            }]
        });

        let text = ResponseParser::openai_response_text(&response).unwrap();
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_response_parser_ollama() {
        let response = serde_json::json!({
            "response": "Hello from Ollama!"
        });

        let text = ResponseParser::ollama_response_text(&response).unwrap();
        assert_eq!(text, "Hello from Ollama!");
    }

    #[test]
    fn test_response_parser_generic() {
        let response1 = serde_json::json!({
            "text": "Generic response"
        });

        let response2 = serde_json::json!({
            "data": {
                "content": "Nested response"
            }
        });

        assert_eq!(ResponseParser::generic_response_text(&response1).unwrap(), "Generic response");
        assert_eq!(ResponseParser::generic_response_text(&response2).unwrap(), "Nested response");
    }
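
    // A sketch of the provider presets and header helpers; the tests module
    // shares the file's module scope, so the private config field is visible.
    #[test]
    fn test_provider_builders_and_headers() {
        let ollama = HttpClient::for_ollama();
        assert_eq!(ollama.config.timeout_seconds, 300);
        assert_eq!(ollama.config.max_retries, 2);
        assert_eq!(
            ollama.config.headers.get("Content-Type").map(String::as_str),
            Some("application/json")
        );

        let client = HttpClient::new()
            .with_bearer_token("secret")
            .with_api_key("key-123");
        assert_eq!(
            client.config.headers.get("Authorization").map(String::as_str),
            Some("Bearer secret")
        );
        assert_eq!(
            client.config.headers.get("X-API-Key").map(String::as_str),
            Some("key-123")
        );
    }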
}