1use super::openai::{send_compatible_streaming_request, AssistantContent};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{StreamingCompletionModel, StreamingResult};
15use crate::{
16 agent::AgentBuilder,
17 completion::{self, CompletionError, CompletionRequest},
18 extractor::ExtractorBuilder,
19 json_utils,
20 providers::openai::Message,
21 OneOrMany,
22};
23use schemars::JsonSchema;
24use serde::{Deserialize, Serialize};
25use serde_json::{json, Value};
26
// Default base URL for Hyperbolic's OpenAI-compatible REST API.
const HYPERBOLIC_API_BASE_URL: &str = "https://api.hyperbolic.xyz/v1";
31
/// HTTP client for the Hyperbolic API.
///
/// Holds the API base URL and a `reqwest` client pre-configured with the
/// `Authorization: Bearer <key>` header, so individual requests only need a path.
#[derive(Clone)]
pub struct Client {
    // Base URL requests are joined against (defaults to `HYPERBOLIC_API_BASE_URL`).
    base_url: String,
    // Shared HTTP client carrying the auth header as a default header.
    http_client: reqwest::Client,
}
37
38impl Client {
39 pub fn new(api_key: &str) -> Self {
41 Self::from_url(api_key, HYPERBOLIC_API_BASE_URL)
42 }
43
44 pub fn from_url(api_key: &str, base_url: &str) -> Self {
46 Self {
47 base_url: base_url.to_string(),
48 http_client: reqwest::Client::builder()
49 .default_headers({
50 let mut headers = reqwest::header::HeaderMap::new();
51 headers.insert(
52 "Authorization",
53 format!("Bearer {}", api_key)
54 .parse()
55 .expect("Bearer token should parse"),
56 );
57 headers
58 })
59 .build()
60 .expect("OpenAI reqwest client should build"),
61 }
62 }
63
64 pub fn from_env() -> Self {
67 let api_key = std::env::var("HYPERBOLIC_API_KEY").expect("HYPERBOLIC_API_KEY not set");
68 Self::new(&api_key)
69 }
70
71 fn post(&self, path: &str) -> reqwest::RequestBuilder {
72 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
73 self.http_client.post(url)
74 }
75
76 pub fn completion_model(&self, model: &str) -> CompletionModel {
88 CompletionModel::new(self.clone(), model)
89 }
90
91 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
106 AgentBuilder::new(self.completion_model(model))
107 }
108
109 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
111 &self,
112 model: &str,
113 ) -> ExtractorBuilder<T, CompletionModel> {
114 ExtractorBuilder::new(self.completion_model(model))
115 }
116}
117
/// Error payload returned by the Hyperbolic API.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    // Human-readable error description from the provider.
    message: String,
}
122
/// Untagged wrapper over a successful payload or the provider's error shape;
/// serde tries `Ok(T)` first and falls back to `Err` on mismatch.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
129
/// A single embedding entry in an OpenAI-style embeddings response.
// NOTE(review): not referenced elsewhere in this chunk — presumably used by an
// embeddings API elsewhere in the file/crate; confirm before removing.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    pub object: String,
    pub embedding: Vec<f64>,
    // Position of this embedding in the batched request.
    pub index: usize,
}
136
/// Token usage accounting reported by the API for a request.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    pub prompt_tokens: usize,
    pub total_tokens: usize,
}
142
143impl std::fmt::Display for Usage {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(
146 f,
147 "Prompt tokens: {} Total tokens: {}",
148 self.prompt_tokens, self.total_tokens
149 )
150 }
151}
152
/// `meta-llama/Meta-Llama-3.1-8B-Instruct` completion model.
pub const LLAMA_3_1_8B: &str = "meta-llama/Meta-Llama-3.1-8B-Instruct";
/// `meta-llama/Llama-3.3-70B-Instruct` completion model.
pub const LLAMA_3_3_70B: &str = "meta-llama/Llama-3.3-70B-Instruct";
/// `meta-llama/Meta-Llama-3.1-70B-Instruct` completion model.
pub const LLAMA_3_1_70B: &str = "meta-llama/Meta-Llama-3.1-70B-Instruct";
/// `meta-llama/Meta-Llama-3-70B-Instruct` completion model.
pub const LLAMA_3_70B: &str = "meta-llama/Meta-Llama-3-70B-Instruct";
/// `NousResearch/Hermes-3-Llama-3.1-70b` completion model.
pub const HERMES_3_70B: &str = "NousResearch/Hermes-3-Llama-3.1-70b";
/// `deepseek-ai/DeepSeek-V2.5` completion model.
pub const DEEPSEEK_2_5: &str = "deepseek-ai/DeepSeek-V2.5";
/// `Qwen/Qwen2.5-72B-Instruct` completion model.
pub const QWEN_2_5_72B: &str = "Qwen/Qwen2.5-72B-Instruct";
/// `meta-llama/Llama-3.2-3B-Instruct` completion model.
pub const LLAMA_3_2_3B: &str = "meta-llama/Llama-3.2-3B-Instruct";
/// `Qwen/Qwen2.5-Coder-32B-Instruct` completion model.
pub const QWEN_2_5_CODER_32B: &str = "Qwen/Qwen2.5-Coder-32B-Instruct";
/// `Qwen/QwQ-32B-Preview` completion model.
pub const QWEN_QWQ_PREVIEW_32B: &str = "Qwen/QwQ-32B-Preview";
/// `deepseek-ai/DeepSeek-R1-Zero` completion model.
pub const DEEPSEEK_R1_ZERO: &str = "deepseek-ai/DeepSeek-R1-Zero";
/// `deepseek-ai/DeepSeek-R1` completion model.
pub const DEEPSEEK_R1: &str = "deepseek-ai/DeepSeek-R1";
180
/// Raw chat-completion response body from the Hyperbolic API
/// (OpenAI-compatible shape).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    // One or more generated alternatives; only the first is consumed downstream.
    pub choices: Vec<Choice>,
    // Token accounting; may be absent in some responses.
    pub usage: Option<Usage>,
}
193
194impl From<ApiErrorResponse> for CompletionError {
195 fn from(err: ApiErrorResponse) -> Self {
196 CompletionError::ProviderError(err.message)
197 }
198}
199
impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    type Error = CompletionError;

    /// Converts the provider's raw response into rig's provider-agnostic
    /// completion response, flattening the first choice's text and tool calls
    /// into one content list.
    ///
    /// # Errors
    /// Returns [`CompletionError::ResponseError`] if the response has no
    /// choices, the first choice is not an assistant message, or the resulting
    /// content list is empty.
    fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
        // Only the first choice is used; any additional choices are discarded.
        let choice = response.choices.first().ok_or_else(|| {
            CompletionError::ResponseError("Response contained no choices".to_owned())
        })?;

        let content = match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                ..
            } => {
                // Text and refusals are both mapped to plain text content.
                let mut content = content
                    .iter()
                    .map(|c| match c {
                        AssistantContent::Text { text } => completion::AssistantContent::text(text),
                        AssistantContent::Refusal { refusal } => {
                            completion::AssistantContent::text(refusal)
                        }
                    })
                    .collect::<Vec<_>>();

                // Tool calls are appended after text content.
                // NOTE(review): the tool-call id slot is filled with the
                // function name (name passed twice) — presumably because the
                // provider's call id is not carried through; confirm against
                // the Hyperbolic API before relying on call ids.
                content.extend(
                    tool_calls
                        .iter()
                        .map(|call| {
                            completion::AssistantContent::tool_call(
                                &call.function.name,
                                &call.function.name,
                                call.function.arguments.clone(),
                            )
                        })
                        .collect::<Vec<_>>(),
                );
                Ok(content)
            }
            _ => Err(CompletionError::ResponseError(
                "Response did not contain a valid message or tool call".into(),
            )),
        }?;

        // OneOrMany::many rejects an empty list, so an assistant message with
        // no content and no tool calls becomes a ResponseError here.
        let choice = OneOrMany::many(content).map_err(|_| {
            CompletionError::ResponseError(
                "Response contained no message or tool call (empty)".to_owned(),
            )
        })?;

        Ok(completion::CompletionResponse {
            choice,
            raw_response: response,
        })
    }
}
255
/// One generated alternative within a chat-completion response.
#[derive(Debug, Deserialize)]
pub struct Choice {
    pub index: usize,
    pub message: Message,
    // Provider-reported stop reason (e.g. length/stop); kept as a raw string.
    pub finish_reason: String,
}
262
/// A Hyperbolic completion model: a [`Client`] paired with a model name.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    // Model identifier sent in the request body (e.g. [`LLAMA_3_1_8B`]).
    pub model: String,
}
269
270impl CompletionModel {
271 pub(crate) fn create_completion_request(
272 &self,
273 completion_request: CompletionRequest,
274 ) -> Result<Value, CompletionError> {
275 let mut full_history: Vec<Message> = match &completion_request.preamble {
277 Some(preamble) => vec![Message::system(preamble)],
278 None => vec![],
279 };
280
281 let prompt: Vec<Message> = completion_request.prompt_with_context().try_into()?;
283
284 let chat_history: Vec<Message> = completion_request
286 .chat_history
287 .into_iter()
288 .map(|message| message.try_into())
289 .collect::<Result<Vec<Vec<Message>>, _>>()?
290 .into_iter()
291 .flatten()
292 .collect();
293
294 full_history.extend(chat_history);
296 full_history.extend(prompt);
297
298 let request = json!({
299 "model": self.model,
300 "messages": full_history,
301 "temperature": completion_request.temperature,
302 });
303
304 let request = if let Some(params) = completion_request.additional_params {
305 json_utils::merge(request, params)
306 } else {
307 request
308 };
309
310 Ok(request)
311 }
312}
313
314impl CompletionModel {
315 pub fn new(client: Client, model: &str) -> Self {
316 Self {
317 client,
318 model: model.to_string(),
319 }
320 }
321}
322
323impl completion::CompletionModel for CompletionModel {
324 type Response = CompletionResponse;
325
326 #[cfg_attr(feature = "worker", worker::send)]
327 async fn completion(
328 &self,
329 completion_request: CompletionRequest,
330 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
331 let request = self.create_completion_request(completion_request)?;
332
333 let response = self
334 .client
335 .post("/chat/completions")
336 .json(&request)
337 .send()
338 .await?;
339
340 if response.status().is_success() {
341 match response.json::<ApiResponse<CompletionResponse>>().await? {
342 ApiResponse::Ok(response) => {
343 tracing::info!(target: "rig",
344 "Hyperbolic completion token usage: {:?}",
345 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
346 );
347
348 response.try_into()
349 }
350 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
351 }
352 } else {
353 Err(CompletionError::ProviderError(response.text().await?))
354 }
355 }
356}
357
358impl StreamingCompletionModel for CompletionModel {
359 async fn stream(
360 &self,
361 completion_request: CompletionRequest,
362 ) -> Result<StreamingResult, CompletionError> {
363 let mut request = self.create_completion_request(completion_request)?;
364
365 merge_inplace(&mut request, json!({"stream": true}));
366
367 let builder = self.client.post("/chat/completions").json(&request);
368
369 send_compatible_streaming_request(builder).await
370 }
371}