// manx_cli/rag/providers/custom.rs

use std::time::Duration;

use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{EmbeddingProvider as ProviderTrait, ProviderInfo};

8/// Custom endpoint embedding provider
9pub struct CustomProvider {
10    client: Client,
11    endpoint_url: String,
12    api_key: Option<String>,
13    dimension: Option<usize>, // Cached dimension
14}
15
16#[derive(Serialize)]
17struct CustomEmbeddingRequest {
18    text: String,
19    #[serde(skip_serializing_if = "Option::is_none")]
20    model: Option<String>,
21}
22
23#[derive(Deserialize)]
24struct CustomEmbeddingResponse {
25    embedding: Vec<f32>,
26    #[serde(default)]
27    dimension: Option<usize>,
28}
29
30impl CustomProvider {
31    /// Create a new custom endpoint provider
32    pub fn new(endpoint_url: String, api_key: Option<String>) -> Self {
33        let client = Client::builder()
34            .timeout(std::time::Duration::from_secs(30))
35            .build()
36            .unwrap();
37
38        Self {
39            client,
40            endpoint_url,
41            api_key,
42            dimension: None,
43        }
44    }
45
46    /// Detect dimension by making a test API call
47    #[allow(dead_code)]
48    pub async fn detect_dimension(&mut self) -> Result<usize> {
49        if let Some(dim) = self.dimension {
50            return Ok(dim);
51        }
52
53        log::info!(
54            "Detecting embedding dimension for custom endpoint: {}",
55            self.endpoint_url
56        );
57
58        let test_embedding = self.call_api("test").await?;
59        let dimension = test_embedding.len();
60
61        self.dimension = Some(dimension);
62        log::info!(
63            "Detected dimension: {} for endpoint {}",
64            dimension,
65            self.endpoint_url
66        );
67
68        Ok(dimension)
69    }
70
71    /// Make API call to custom endpoint
72    async fn call_api(&self, text: &str) -> Result<Vec<f32>> {
73        let request = CustomEmbeddingRequest {
74            text: text.to_string(),
75            model: None,
76        };
77
78        let mut request_builder = self
79            .client
80            .post(&self.endpoint_url)
81            .header("Content-Type", "application/json")
82            .json(&request);
83
84        // Add API key if provided
85        if let Some(ref api_key) = self.api_key {
86            request_builder =
87                request_builder.header("Authorization", format!("Bearer {}", api_key));
88        }
89
90        let response = request_builder.send().await?;
91
92        let status = response.status();
93        if !status.is_success() {
94            let error_text = response.text().await.unwrap_or_default();
95            return Err(anyhow!(
96                "Custom endpoint error: HTTP {} - {}",
97                status,
98                error_text
99            ));
100        }
101
102        let embedding_response: CustomEmbeddingResponse = response.json().await?;
103
104        if embedding_response.embedding.is_empty() {
105            return Err(anyhow!("No embeddings returned from custom endpoint"));
106        }
107
108        // Cache dimension if provided in response
109        if let Some(dim) = embedding_response.dimension {
110            if self.dimension.is_none() {
111                // Note: This is a mutable operation but we're in an immutable context
112                // In practice, this would need to be handled differently
113                log::info!("Custom endpoint reported dimension: {}", dim);
114            }
115        }
116
117        Ok(embedding_response.embedding)
118    }
119
120    /// Health check for custom endpoint
121    pub async fn check_endpoint(&self) -> Result<()> {
122        // Try a simple GET request first to see if endpoint is reachable
123        let response = self
124            .client
125            .get(&self.endpoint_url)
126            .send()
127            .await
128            .map_err(|e| {
129                anyhow!(
130                    "Failed to connect to custom endpoint {}: {}",
131                    self.endpoint_url,
132                    e
133                )
134            })?;
135
136        // Accept any non-server-error response for basic connectivity
137        if response.status().as_u16() >= 500 {
138            return Err(anyhow!(
139                "Custom endpoint returned server error: HTTP {}",
140                response.status()
141            ));
142        }
143
144        Ok(())
145    }
146}
147
148#[async_trait]
149impl ProviderTrait for CustomProvider {
150    async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
151        if text.trim().is_empty() {
152            return Err(anyhow!("Cannot embed empty text"));
153        }
154
155        self.call_api(text).await
156    }
157
158    async fn get_dimension(&self) -> Result<usize> {
159        if let Some(dim) = self.dimension {
160            Ok(dim)
161        } else {
162            Err(anyhow!("Dimension not known for custom endpoint {}. Use 'manx embedding test' to detect it.", self.endpoint_url))
163        }
164    }
165
166    async fn health_check(&self) -> Result<()> {
167        self.check_endpoint().await?;
168        self.call_api("test").await.map(|_| ())
169    }
170
171    fn get_info(&self) -> ProviderInfo {
172        ProviderInfo {
173            name: "Custom Endpoint".to_string(),
174            provider_type: "custom".to_string(),
175            model_name: None,
176            description: format!("Custom embedding endpoint: {}", self.endpoint_url),
177            max_input_length: None, // Unknown
178        }
179    }
180}