nabla_cli/enterprise/providers/http.rs

// src/enterprise/providers/http.rs
use reqwest::Client;
use serde_json::json;
use async_trait::async_trait;
use super::types::{InferenceProvider, GenerationOptions, GenerationResponse, InferenceError};

/// HTTP-backed inference provider. Requests go to a llama.cpp server's
/// `/completion` endpoint when a model path or Hugging Face repo is set,
/// and otherwise to an OpenAI-compatible `/v1/chat/completions` endpoint.
#[derive(Clone)]
pub struct HTTPProvider {
    client: Client,
    inference_url: String,
    api_key: Option<String>,
    provider_token: Option<String>,
}

impl HTTPProvider {
    /// Build a provider for the given base URL. When both credentials are
    /// present, `api_key` takes precedence over `provider_token`.
    pub fn new(inference_url: String, api_key: Option<String>, provider_token: Option<String>) -> Self {
        Self {
            client: Client::new(),
            inference_url,
            api_key,
            provider_token,
        }
    }
}
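
// Illustrative only (not from the original source): a minimal sketch of how the
// provider might be used. It assumes a Tokio runtime, an async caller returning a
// Result, and a `GenerationOptions` value built elsewhere; the exact constructor
// for `GenerationOptions` lives in `super::types` and may differ from this sketch.
//
//     let provider = HTTPProvider::new(
//         "http://localhost:8080".to_string(), // assumed local llama.cpp server URL
//         None,                                // no API key
//         None,                                // no provider token
//     );
//     if provider.is_available().await {
//         let response = provider.generate("Hello", &options).await?;
//         println!("{} ({} tokens)", response.text, response.tokens_used);
//     }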

#[async_trait]
impl InferenceProvider for HTTPProvider {
    async fn generate(&self, prompt: &str, options: &GenerationOptions) -> Result<GenerationResponse, InferenceError> {
        // Check if this is a Hugging Face repo request or has a model path (llama.cpp server)
        if options.hf_repo.is_some() || options.model_path.is_some() {
            // Use llama.cpp server's completion endpoint
            let mut request_json = json!({
                "prompt": prompt,
                "n_predict": options.max_tokens,
                "temperature": options.temperature,
                "top_p": options.top_p,
                "stop": options.stop_sequences,
            });

            // Add hf_repo if provided
            if let Some(hf_repo) = &options.hf_repo {
                request_json["hf_repo"] = json!(hf_repo);
            }

            let mut request = self.client
                .post(&format!("{}/completion", self.inference_url))
                .json(&request_json);

            // Add authentication headers
            if let Some(key) = &self.api_key {
                request = request.header("Authorization", format!("Bearer {}", key));
            } else if let Some(token) = &self.provider_token {
                request = request.header("Authorization", format!("Bearer {}", token));
            }

            let response = request.send().await
                .map_err(|e| InferenceError::NetworkError(e.to_string()))?;

            if !response.status().is_success() {
                return Err(InferenceError::ServerError(format!(
                    "Server returned status: {}", response.status()
                )));
            }

            let result: serde_json::Value = response.json().await
                .map_err(|e| InferenceError::ServerError(format!("Failed to parse response: {}", e)))?;

            // Parse llama.cpp server response format
            let content = result["content"].as_str()
                .ok_or_else(|| InferenceError::ServerError("Missing content in response".to_string()))?;
            let tokens_used = result["tokens_predicted"].as_u64().unwrap_or(0) as usize;
            let stop_reason = result["stop_type"].as_str().unwrap_or("").to_string();
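            // These fields come from the llama.cpp server's `/completion` response,
            // which (abridged, and subject to change across server versions) looks
            // roughly like:
            //     { "content": "...", "tokens_predicted": 42, "stop_type": "eos", ... }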

            Ok(GenerationResponse {
                text: content.to_string(),
                tokens_used,
                finish_reason: stop_reason,
            })
        } else {
            // Fallback to OpenAI-compatible endpoint
            let mut request = self.client
                .post(&format!("{}/v1/chat/completions", self.inference_url))
                .json(&json!({
                    "model": options.model.as_deref().unwrap_or("gpt-3.5-turbo"),
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": options.max_tokens,
                    "temperature": options.temperature,
                    "top_p": options.top_p,
                }));

            // Add authentication headers
            if let Some(key) = &self.api_key {
                request = request.header("Authorization", format!("Bearer {}", key));
            } else if let Some(token) = &self.provider_token {
                request = request.header("Authorization", format!("Bearer {}", token));
            }

            let response = request.send().await
                .map_err(|e| InferenceError::NetworkError(e.to_string()))?;

            if !response.status().is_success() {
                return Err(InferenceError::ServerError(format!(
                    "Server returned status: {}", response.status()
                )));
            }

            let result: serde_json::Value = response.json().await
                .map_err(|e| InferenceError::ServerError(format!("Failed to parse response: {}", e)))?;

            // Parse OpenAI-compatible response format
            let content = result["choices"][0]["message"]["content"].as_str()
                .ok_or_else(|| InferenceError::ServerError("Missing content in response".to_string()))?;
            let tokens_used = result["usage"]["total_tokens"].as_u64().unwrap_or(0) as usize;
            let finish_reason = result["choices"][0]["finish_reason"].as_str().unwrap_or("").to_string();
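            // These fields follow the standard chat-completions response shape, roughly:
            //     { "choices": [{ "message": { "content": "..." }, "finish_reason": "stop" }],
            //       "usage": { "total_tokens": 42 } }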

            Ok(GenerationResponse {
                text: content.to_string(),
                tokens_used,
                finish_reason,
            })
        }
    }

    async fn is_available(&self) -> bool {
        // Try to ping the server. `/props` is a llama.cpp server endpoint, so a
        // plain OpenAI-compatible backend may not answer this health check.
        match self.client.get(&format!("{}/props", self.inference_url)).send().await {
            Ok(response) => response.status().is_success(),
            Err(_) => false,
        }
    }
}
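
// ---------------------------------------------------------------------------
// For orientation only (not part of the original file): a rough sketch of the
// `super::types` items this module relies on, inferred from how they are used
// above. The real definitions in nabla_cli may use different field types,
// extra fields, or additional error variants.
//
//     #[async_trait]
//     pub trait InferenceProvider: Send + Sync {
//         async fn generate(&self, prompt: &str, options: &GenerationOptions)
//             -> Result<GenerationResponse, InferenceError>;
//         async fn is_available(&self) -> bool;
//     }
//
//     pub struct GenerationOptions {
//         pub model: Option<String>,        // OpenAI-compatible model name
//         pub model_path: Option<String>,   // local model path (llama.cpp)
//         pub hf_repo: Option<String>,      // Hugging Face repo (llama.cpp)
//         pub max_tokens: u32,              // assumed integer type
//         pub temperature: f32,             // assumed float type
//         pub top_p: f32,                   // assumed float type
//         pub stop_sequences: Vec<String>,
//     }
//
//     pub struct GenerationResponse {
//         pub text: String,
//         pub tokens_used: usize,
//         pub finish_reason: String,
//     }
//
//     pub enum InferenceError {
//         NetworkError(String),
//         ServerError(String),
//     }
// ---------------------------------------------------------------------------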