manx_cli/rag/providers/
huggingface.rs1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::Serialize;
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
8pub struct HuggingFaceProvider {
10 client: Client,
11 api_key: String,
12 model: String,
13 dimension: Option<usize>, }
15
16#[derive(Serialize)]
17struct HfEmbeddingRequest {
18 inputs: String,
19 options: HfOptions,
20}
21
22#[derive(Serialize)]
23struct HfOptions {
24 wait_for_model: bool,
25}
26
27impl HuggingFaceProvider {
28 pub fn new(api_key: String, model: String) -> Self {
30 let client = Client::builder()
31 .timeout(std::time::Duration::from_secs(60)) .build()
33 .unwrap();
34
35 Self {
36 client,
37 api_key,
38 model,
39 dimension: None,
40 }
41 }
42
43 #[allow(dead_code)]
45 pub async fn detect_dimension(&mut self) -> Result<usize> {
46 if let Some(dim) = self.dimension {
47 return Ok(dim);
48 }
49
50 log::info!(
51 "Detecting embedding dimension for HuggingFace model: {}",
52 self.model
53 );
54
55 let test_embedding = self.call_api("test").await?;
56 let dimension = test_embedding.len();
57
58 self.dimension = Some(dimension);
59 log::info!("Detected dimension: {} for model {}", dimension, self.model);
60
61 Ok(dimension)
62 }
63
64 async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
66 let request = HfEmbeddingRequest {
67 inputs: text.to_string(),
68 options: HfOptions {
69 wait_for_model: true,
70 },
71 };
72
73 let url = format!("https://api-inference.huggingface.co/models/{}", self.model);
74
75 let response = self
76 .client
77 .post(&url)
78 .header("Authorization", format!("Bearer {}", self.api_key))
79 .header("Content-Type", "application/json")
80 .json(&request)
81 .send()
82 .await?;
83
84 let status = response.status();
85 if !status.is_success() {
86 let error_text = response.text().await.unwrap_or_default();
87 return Err(anyhow!(
88 "HuggingFace API error: HTTP {} - {}",
89 status,
90 error_text
91 ));
92 }
93
94 let embeddings: Vec<f32> = response.json().await?;
96
97 if embeddings.is_empty() {
98 return Err(anyhow!("No embeddings returned from HuggingFace API"));
99 }
100
101 Ok(embeddings)
102 }
103
104 pub fn get_model_info(model: &str) -> (Option<usize>, usize) {
106 match model {
107 "sentence-transformers/all-MiniLM-L6-v2" => (Some(384), 512),
108 "sentence-transformers/all-mpnet-base-v2" => (Some(768), 512),
109 "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" => (Some(384), 512),
110 "BAAI/bge-small-en-v1.5" => (Some(384), 512),
111 "BAAI/bge-base-en-v1.5" => (Some(768), 512),
112 "BAAI/bge-large-en-v1.5" => (Some(1024), 512),
113 _ => (None, 512), }
115 }
116}
117
118#[async_trait]
119impl ProviderTrait for HuggingFaceProvider {
120 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
121 if text.trim().is_empty() {
122 return Err(anyhow!("Cannot embed empty text"));
123 }
124
125 let (_, max_chars) = Self::get_model_info(&self.model);
127 let truncated_text = if text.len() > max_chars * 4 {
128 &text[..max_chars * 4]
130 } else {
131 text
132 };
133
134 self.call_api(truncated_text).await
135 }
136
137 async fn get_dimension(&self) -> Result<usize> {
138 if let Some(dim) = self.dimension {
139 Ok(dim)
140 } else {
141 let (known_dim, _) = Self::get_model_info(&self.model);
143 if let Some(dim) = known_dim {
144 Ok(dim)
145 } else {
146 Err(anyhow!(
148 "Dimension not known for model {}. Use 'manx embedding test' to detect it.",
149 self.model
150 ))
151 }
152 }
153 }
154
155 async fn health_check(&self) -> Result<()> {
156 self.call_api("test").await.map(|_| ())
157 }
158
159 fn get_info(&self) -> ProviderInfo {
160 let (_, max_length) = Self::get_model_info(&self.model);
161
162 ProviderInfo {
163 name: "HuggingFace Inference API".to_string(),
164 provider_type: "huggingface".to_string(),
165 model_name: Some(self.model.clone()),
166 description: format!("HuggingFace model: {}", self.model),
167 max_input_length: Some(max_length),
168 }
169 }
170
171 fn as_any(&self) -> &dyn std::any::Any {
172 self
173 }
174}