manx_cli/rag/providers/
custom.rs1use anyhow::{anyhow, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};
7
8pub struct CustomProvider {
10 client: Client,
11 endpoint_url: String,
12 api_key: Option<String>,
13 dimension: Option<usize>, }
15
16#[derive(Serialize)]
17struct CustomEmbeddingRequest {
18 text: String,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 model: Option<String>,
21}
22
23#[derive(Deserialize)]
24struct CustomEmbeddingResponse {
25 embedding: Vec<f32>,
26 #[serde(default)]
27 dimension: Option<usize>,
28}
29
30impl CustomProvider {
31 pub fn new(endpoint_url: String, api_key: Option<String>) -> Self {
33 let client = Client::builder()
34 .timeout(std::time::Duration::from_secs(30))
35 .build()
36 .unwrap();
37
38 Self {
39 client,
40 endpoint_url,
41 api_key,
42 dimension: None,
43 }
44 }
45
46 #[allow(dead_code)]
48 pub async fn detect_dimension(&mut self) -> Result<usize> {
49 if let Some(dim) = self.dimension {
50 return Ok(dim);
51 }
52
53 log::info!(
54 "Detecting embedding dimension for custom endpoint: {}",
55 self.endpoint_url
56 );
57
58 let test_embedding = self.call_api("test").await?;
59 let dimension = test_embedding.len();
60
61 self.dimension = Some(dimension);
62 log::info!(
63 "Detected dimension: {} for endpoint {}",
64 dimension,
65 self.endpoint_url
66 );
67
68 Ok(dimension)
69 }
70
71 async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
73 let request = CustomEmbeddingRequest {
74 text: text.to_string(),
75 model: None,
76 };
77
78 let mut request_builder = self
79 .client
80 .post(&self.endpoint_url)
81 .header("Content-Type", "application/json")
82 .json(&request);
83
84 if let Some(ref api_key) = self.api_key {
86 request_builder =
87 request_builder.header("Authorization", format!("Bearer {}", api_key));
88 }
89
90 let response = request_builder.send().await?;
91
92 let status = response.status();
93 if !status.is_success() {
94 let error_text = response.text().await.unwrap_or_default();
95 return Err(anyhow!(
96 "Custom endpoint error: HTTP {} - {}",
97 status,
98 error_text
99 ));
100 }
101
102 let embedding_response: CustomEmbeddingResponse = response.json().await?;
103
104 if embedding_response.embedding.is_empty() {
105 return Err(anyhow!("No embeddings returned from custom endpoint"));
106 }
107
108 if let Some(dim) = embedding_response.dimension {
110 if self.dimension.is_none() {
111 log::info!("Custom endpoint reported dimension: {}", dim);
114 }
115 }
116
117 Ok(embedding_response.embedding)
118 }
119
120 pub async fn check_endpoint(&self) -> Result<()> {
122 let response = self
124 .client
125 .get(&self.endpoint_url)
126 .send()
127 .await
128 .map_err(|e| {
129 anyhow!(
130 "Failed to connect to custom endpoint {}: {}",
131 self.endpoint_url,
132 e
133 )
134 })?;
135
136 if response.status().as_u16() >= 500 {
138 return Err(anyhow!(
139 "Custom endpoint returned server error: HTTP {}",
140 response.status()
141 ));
142 }
143
144 Ok(())
145 }
146}
147
148#[async_trait]
149impl ProviderTrait for CustomProvider {
150 async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
151 if text.trim().is_empty() {
152 return Err(anyhow!("Cannot embed empty text"));
153 }
154
155 self.call_api(text).await
156 }
157
158 async fn get_dimension(&self) -> Result<usize> {
159 if let Some(dim) = self.dimension {
160 Ok(dim)
161 } else {
162 Err(anyhow!("Dimension not known for custom endpoint {}. Use 'manx embedding test' to detect it.", self.endpoint_url))
163 }
164 }
165
166 async fn health_check(&self) -> Result<()> {
167 self.check_endpoint().await?;
168 self.call_api("test").await.map(|_| ())
169 }
170
171 fn get_info(&self) -> ProviderInfo {
172 ProviderInfo {
173 name: "Custom Endpoint".to_string(),
174 provider_type: "custom".to_string(),
175 model_name: None,
176 description: format!("Custom embedding endpoint: {}", self.endpoint_url),
177 max_input_length: None, }
179 }
180}