nabla_cli/enterprise/providers/http.rs

//! HTTP-backed inference provider. Requests that reference a local model
//! (`hf_repo` or `model_path`) are sent to the server's `/completion`
//! endpoint; all other requests go to the OpenAI-compatible
//! `/v1/chat/completions` endpoint.

use reqwest::Client;
use serde_json::json;
use async_trait::async_trait;
use super::types::{InferenceProvider, GenerationOptions, GenerationResponse, InferenceError};

#[derive(Clone)]
pub struct HTTPProvider {
    client: Client,
    inference_url: String,
    /// Preferred credential; sent as a bearer token when present.
    api_key: Option<String>,
    /// Fallback credential, used only when `api_key` is `None`.
    provider_token: Option<String>,
}

impl HTTPProvider {
    pub fn new(inference_url: String, api_key: Option<String>, provider_token: Option<String>) -> Self {
        Self {
            client: Client::new(),
            inference_url,
            api_key,
            provider_token,
        }
    }
}

#[async_trait]
impl InferenceProvider for HTTPProvider {
    async fn generate(&self, prompt: &str, options: &GenerationOptions) -> Result<GenerationResponse, InferenceError> {
        if options.hf_repo.is_some() || options.model_path.is_some() {
            // Local-model path: build the /completion request body once and
            // append optional fields, rather than constructing it twice.
            let mut request_json = json!({
                "prompt": prompt,
                "n_predict": options.max_tokens,
                "temperature": options.temperature,
                "top_p": options.top_p,
                "stop": options.stop_sequences,
            });

            // Note: `model_path` selects this branch but is not forwarded in
            // the request body; only `hf_repo` is.
            if let Some(hf_repo) = &options.hf_repo {
                request_json["hf_repo"] = json!(hf_repo);
            }

            let mut request = self.client
                .post(&format!("{}/completion", self.inference_url))
                .json(&request_json);

            // `api_key` takes precedence over `provider_token`.
            if let Some(key) = &self.api_key {
                request = request.header("Authorization", format!("Bearer {}", key));
            } else if let Some(token) = &self.provider_token {
                request = request.header("Authorization", format!("Bearer {}", token));
            }

            let response = request.send().await
                .map_err(|e| InferenceError::NetworkError(e.to_string()))?;

            if !response.status().is_success() {
                return Err(InferenceError::ServerError(format!(
                    "Server returned status: {}", response.status()
                )));
            }

            let result: serde_json::Value = response.json().await
                .map_err(|e| InferenceError::ServerError(format!("Failed to parse response: {}", e)))?;

            let content = result["content"].as_str()
                .ok_or_else(|| InferenceError::ServerError("Missing content in response".to_string()))?;
            let tokens_used = result["tokens_predicted"].as_u64().unwrap_or(0) as usize;
            let stop_reason = result["stop_type"].as_str().unwrap_or("").to_string();

            Ok(GenerationResponse {
                text: content.to_string(),
                tokens_used,
                finish_reason: stop_reason,
            })
        } else {
            let mut request = self.client
                .post(&format!("{}/v1/chat/completions", self.inference_url))
                .json(&json!({
                    "model": options.model.as_deref().unwrap_or("gpt-3.5-turbo"),
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": options.max_tokens,
                    "temperature": options.temperature,
                    "top_p": options.top_p,
                }));

            if let Some(key) = &self.api_key {
                request = request.header("Authorization", format!("Bearer {}", key));
            } else if let Some(token) = &self.provider_token {
                request = request.header("Authorization", format!("Bearer {}", token));
            }

            let response = request.send().await
                .map_err(|e| InferenceError::NetworkError(e.to_string()))?;

            if !response.status().is_success() {
                return Err(InferenceError::ServerError(format!(
                    "Server returned status: {}", response.status()
                )));
            }

            let result: serde_json::Value = response.json().await
                .map_err(|e| InferenceError::ServerError(format!("Failed to parse response: {}", e)))?;

            let content = result["choices"][0]["message"]["content"].as_str()
                .ok_or_else(|| InferenceError::ServerError("Missing content in response".to_string()))?;
            let tokens_used = result["usage"]["total_tokens"].as_u64().unwrap_or(0) as usize;
            let finish_reason = result["choices"][0]["finish_reason"].as_str().unwrap_or("").to_string();

            Ok(GenerationResponse {
                text: content.to_string(),
                tokens_used,
                finish_reason,
            })
        }
    }

    async fn is_available(&self) -> bool {
        match self.client.get(&format!("{}/props", self.inference_url)).send().await {
            Ok(response) => response.status().is_success(),
            Err(_) => false,
        }
    }
}
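
A minimal usage sketch, not part of the original file: it assumes a tokio runtime and that GenerationOptions implements Default, neither of which is confirmed by this source.

// Hypothetical example; assumes `tokio` is a dependency and that
// `GenerationOptions: Default` holds in the surrounding crate.
#[tokio::main]
async fn main() {
    // Hypothetical endpoint; api_key is preferred, provider_token is the fallback.
    let provider = HTTPProvider::new(
        "http://localhost:8080".to_string(),
        None,                        // api_key
        Some("secret".to_string()),  // provider_token
    );

    if provider.is_available().await {
        let options = GenerationOptions::default(); // assumed Default impl
        match provider.generate("Hello!", &options).await {
            Ok(reply) => println!("{} ({} tokens)", reply.text, reply.tokens_used),
            Err(_) => eprintln!("generation failed"),
        }
    }
}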