// manx_cli/rag/providers/huggingface.rs
1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::Serialize;
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
/// HuggingFace Inference API embedding provider.
///
/// Sends texts to the hosted Inference API for the configured model and
/// returns embedding vectors. Requires a HuggingFace API token.
pub struct HuggingFaceProvider {
    client: Client,           // reusable HTTP client (60s timeout, built in `new`)
    api_key: String,          // bearer token sent in the Authorization header
    model: String,            // model id, e.g. "sentence-transformers/all-MiniLM-L6-v2"
    dimension: Option<usize>, // Cached dimension
}
15
/// Request body for the HuggingFace Inference API feature-extraction endpoint.
/// Field names are part of the wire format — do not rename.
#[derive(Serialize)]
struct HfEmbeddingRequest {
    inputs: String,      // text to embed
    options: HfOptions,  // API behavior flags
}
21
/// Options object nested inside [`HfEmbeddingRequest`].
#[derive(Serialize)]
struct HfOptions {
    // When true, the API blocks until a cold model is loaded instead of
    // returning an error immediately.
    wait_for_model: bool,
}
26
27impl HuggingFaceProvider {
28    /// Create a new HuggingFace provider
29    pub fn new(api_key: String, model: String) -> Self {
30        let client = Client::builder()
31            .timeout(std::time::Duration::from_secs(60)) // HF can be slower
32            .build()
33            .unwrap();
34
35        Self {
36            client,
37            api_key,
38            model,
39            dimension: None,
40        }
41    }
42
43    /// Detect dimension by making a test API call
44    #[allow(dead_code)]
45    pub async fn detect_dimension(&mut self) -> Result<usize> {
46        if let Some(dim) = self.dimension {
47            return Ok(dim);
48        }
49
50        log::info!(
51            "Detecting embedding dimension for HuggingFace model: {}",
52            self.model
53        );
54
55        let test_embedding = self.call_api("test").await?;
56        let dimension = test_embedding.len();
57
58        self.dimension = Some(dimension);
59        log::info!("Detected dimension: {} for model {}", dimension, self.model);
60
61        Ok(dimension)
62    }
63
64    /// Make API call to HuggingFace Inference API
65    async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
66        let request = HfEmbeddingRequest {
67            inputs: text.to_string(),
68            options: HfOptions {
69                wait_for_model: true,
70            },
71        };
72
73        let url = format!("https://api-inference.huggingface.co/models/{}", self.model);
74
75        let response = self
76            .client
77            .post(&url)
78            .header("Authorization", format!("Bearer {}", self.api_key))
79            .header("Content-Type", "application/json")
80            .json(&request)
81            .send()
82            .await?;
83
84        let status = response.status();
85        if !status.is_success() {
86            let error_text = response.text().await.unwrap_or_default();
87            return Err(anyhow!(
88                "HuggingFace API error: HTTP {} - {}",
89                status,
90                error_text
91            ));
92        }
93
94        // HuggingFace returns embeddings as a flat array
95        let embeddings: Vec<f32> = response.json().await?;
96
97        if embeddings.is_empty() {
98            return Err(anyhow!("No embeddings returned from HuggingFace API"));
99        }
100
101        Ok(embeddings)
102    }
103
104    /// Get common HuggingFace model information (dimension, max_length)
105    pub fn get_model_info(model: &str) -> (Option<usize>, usize) {
106        match model {
107            "sentence-transformers/all-MiniLM-L6-v2" => (Some(384), 512),
108            "sentence-transformers/all-mpnet-base-v2" => (Some(768), 512),
109            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" => (Some(384), 512),
110            "BAAI/bge-small-en-v1.5" => (Some(384), 512),
111            "BAAI/bge-base-en-v1.5" => (Some(768), 512),
112            "BAAI/bge-large-en-v1.5" => (Some(1024), 512),
113            _ => (None, 512), // Unknown model, will detect dynamically
114        }
115    }
116}
117
118#[async_trait]
119impl ProviderTrait for HuggingFaceProvider {
120    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
121        if text.trim().is_empty() {
122            return Err(anyhow!("Cannot embed empty text"));
123        }
124
125        // Truncate text if too long
126        let (_, max_chars) = Self::get_model_info(&self.model);
127        let truncated_text = if text.len() > max_chars * 4 {
128            // Rough token approximation
129            &text[..max_chars * 4]
130        } else {
131            text
132        };
133
134        self.call_api(truncated_text).await
135    }
136
137    async fn get_dimension(&self) -> Result<usize> {
138        if let Some(dim) = self.dimension {
139            Ok(dim)
140        } else {
141            // Try to use known dimensions for common models
142            let (known_dim, _) = Self::get_model_info(&self.model);
143            if let Some(dim) = known_dim {
144                Ok(dim)
145            } else {
146                // Need to detect dynamically
147                Err(anyhow!(
148                    "Dimension not known for model {}. Use 'manx embedding test' to detect it.",
149                    self.model
150                ))
151            }
152        }
153    }
154
155    async fn health_check(&self) -> Result<()> {
156        self.call_api("test").await.map(|_| ())
157    }
158
159    fn get_info(&self) -> ProviderInfo {
160        let (_, max_length) = Self::get_model_info(&self.model);
161
162        ProviderInfo {
163            name: "HuggingFace Inference API".to_string(),
164            provider_type: "huggingface".to_string(),
165            model_name: Some(self.model.clone()),
166            description: format!("HuggingFace model: {}", self.model),
167            max_input_length: Some(max_length),
168        }
169    }
170
171    fn as_any(&self) -> &dyn std::any::Any {
172        self
173    }
174}