Skip to main content

hoist_client/
client.rs

1//! Azure Search REST API client
2
3use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15/// Maximum number of retry attempts for retryable errors
16const MAX_RETRIES: u32 = 3;
17
18/// Initial backoff delay in seconds
19const INITIAL_BACKOFF_SECS: u64 = 1;
20
21/// Calculate the backoff duration for a given retry attempt.
22///
23/// For `RateLimited` errors with a `retry_after` value, that value is used directly.
24/// For other retryable errors, exponential backoff is applied: 1s, 2s, 4s, etc.
25fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26    match error {
27        ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28        _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29    }
30}
31
/// API version for agentic retrieval resources (KnowledgeBase, KnowledgeSource).
///
/// `AzureSearchClient::api_version_for` returns this value for those two resource
/// kinds instead of the client's configured preview API version.
///
/// Pinned to `2025-08-01-preview` because `2025-11-01-preview` introduced breaking
/// schema changes: fields like `language`, `production_family`, `embeddingModel`,
/// and `chatCompletionModel` were reorganized into a nested `ingestionParameters`
/// object, and `sourceDataSelect` was renamed to `sourceDataFields`. Knowledge
/// sources created with the older schema cannot be updated through `2025-11-01-preview`
/// without migration (creating new objects with new names).
///
/// See: https://learn.microsoft.com/en-us/azure/search/agentic-retrieval-how-to-migrate
const AGENTIC_RETRIEVAL_API_VERSION: &str = "2025-08-01-preview";
43
/// Azure Search API client
///
/// Bundles the HTTP client, the auth provider that supplies bearer tokens, and the
/// URL / API-version settings used to build request URLs.
pub struct AzureSearchClient {
    /// HTTP client used to send every request.
    http: Client,
    /// Supplies the token placed in the `Authorization: Bearer ...` header.
    auth: Box<dyn AuthProvider>,
    /// URL prefix that resource paths are appended to when building request URLs.
    base_url: String,
    /// API version used for every resource kind except KnowledgeBase / KnowledgeSource.
    preview_api_version: String,
}
51
52impl AzureSearchClient {
53    /// Create a new client from configuration
54    pub fn new(config: &Config) -> Result<Self, ClientError> {
55        let auth = get_auth_provider()?;
56        let http = Client::builder()
57            .timeout(std::time::Duration::from_secs(30))
58            .build()?;
59
60        Ok(Self {
61            http,
62            auth,
63            base_url: config.service_url(),
64            preview_api_version: config.api_version_for(true).to_string(),
65        })
66    }
67
68    /// Create a client pointing to a different server, using the same auth and API versions
69    pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
70        let auth = get_auth_provider()?;
71        let http = Client::builder()
72            .timeout(std::time::Duration::from_secs(30))
73            .build()?;
74
75        Ok(Self {
76            http,
77            auth,
78            base_url: format!("https://{}.search.windows.net", server_name),
79            preview_api_version: config.api_version_for(true).to_string(),
80        })
81    }
82
83    /// Create with a custom auth provider (for testing)
84    pub fn with_auth(
85        base_url: String,
86        preview_api_version: String,
87        auth: Box<dyn AuthProvider>,
88    ) -> Result<Self, ClientError> {
89        let http = Client::builder()
90            .timeout(std::time::Duration::from_secs(30))
91            .build()?;
92
93        Ok(Self {
94            http,
95            auth,
96            base_url,
97            preview_api_version,
98        })
99    }
100
101    /// Get the API version to use for a resource kind.
102    ///
103    /// - KnowledgeBase / KnowledgeSource use `AGENTIC_RETRIEVAL_API_VERSION`
104    ///   (`2025-08-01-preview`) because `2025-11-01-preview` has breaking schema changes.
105    /// - All other resources use the preview API version — it is a superset of the
106    ///   stable API and avoids failures when stable resources contain preview-only
107    ///   features (e.g. a skillset with ChatCompletionSkill).
108    fn api_version_for(&self, kind: ResourceKind) -> &str {
109        match kind {
110            ResourceKind::KnowledgeBase | ResourceKind::KnowledgeSource => {
111                AGENTIC_RETRIEVAL_API_VERSION
112            }
113            _ => &self.preview_api_version,
114        }
115    }
116
117    /// Build URL for a resource collection
118    fn collection_url(&self, kind: ResourceKind) -> String {
119        format!(
120            "{}/{}?api-version={}",
121            self.base_url,
122            kind.api_path(),
123            self.api_version_for(kind)
124        )
125    }
126
127    /// Build URL for a specific resource
128    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
129        format!(
130            "{}/{}/{}?api-version={}",
131            self.base_url,
132            kind.api_path(),
133            name,
134            self.api_version_for(kind)
135        )
136    }
137
138    /// Execute an HTTP request
139    async fn request(
140        &self,
141        method: Method,
142        url: &str,
143        body: Option<&Value>,
144    ) -> Result<Option<Value>, ClientError> {
145        let token = self.auth.get_token()?;
146
147        let mut request = self
148            .http
149            .request(method.clone(), url)
150            .header("Authorization", format!("Bearer {}", token))
151            .header("Content-Type", "application/json");
152
153        if let Some(json) = body {
154            request = request.json(json);
155        }
156
157        debug!("Request: {} {}", method, url);
158        let response = request.send().await?;
159        let status = response.status();
160
161        if status == StatusCode::NO_CONTENT {
162            return Ok(None);
163        }
164
165        let body = response.text().await?;
166
167        if status.is_success() {
168            if body.is_empty() {
169                Ok(None)
170            } else {
171                let value: Value = serde_json::from_str(&body)?;
172                Ok(Some(value))
173            }
174        } else {
175            match status {
176                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
177                    kind: "resource".to_string(),
178                    name: url.to_string(),
179                }),
180                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
181                    kind: "resource".to_string(),
182                    name: url.to_string(),
183                }),
184                StatusCode::TOO_MANY_REQUESTS => {
185                    let retry_after = 60; // Default retry time
186                    Err(ClientError::RateLimited { retry_after })
187                }
188                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
189                _ => Err(ClientError::from_response_with_url(
190                    status.as_u16(),
191                    &body,
192                    Some(url),
193                )),
194            }
195        }
196    }
197
198    /// Execute an HTTP request with retry logic for transient errors.
199    ///
200    /// Retries up to [`MAX_RETRIES`] times for retryable errors (429 and 503).
201    /// Uses exponential backoff (1s, 2s, 4s) for 503 errors and respects the
202    /// `retry_after` value for 429 rate-limiting errors.
203    async fn request_with_retry(
204        &self,
205        method: Method,
206        url: &str,
207        body: Option<&Value>,
208    ) -> Result<Option<Value>, ClientError> {
209        let mut attempt = 0u32;
210        loop {
211            match self.request(method.clone(), url, body).await {
212                Ok(value) => return Ok(value),
213                Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
214                    let delay = retry_delay(&err, attempt);
215                    warn!(
216                        "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
217                        method,
218                        url,
219                        attempt + 1,
220                        MAX_RETRIES + 1,
221                        err,
222                        delay,
223                    );
224                    tokio::time::sleep(delay).await;
225                    attempt += 1;
226                }
227                Err(err) => return Err(err),
228            }
229        }
230    }
231
232    /// List all resources of a given kind
233    #[instrument(skip(self))]
234    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
235        let url = self.collection_url(kind);
236        let response = self.request_with_retry(Method::GET, &url, None).await?;
237
238        match response {
239            Some(value) => {
240                // Azure returns { "value": [...] }
241                let items = value
242                    .get("value")
243                    .and_then(|v| v.as_array())
244                    .cloned()
245                    .unwrap_or_default();
246                Ok(items)
247            }
248            None => Ok(Vec::new()),
249        }
250    }
251
252    /// Get a specific resource
253    #[instrument(skip(self))]
254    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
255        let url = self.resource_url(kind, name);
256        let response = self.request_with_retry(Method::GET, &url, None).await?;
257
258        response.ok_or_else(|| ClientError::NotFound {
259            kind: kind.display_name().to_string(),
260            name: name.to_string(),
261        })
262    }
263
264    /// Create or update a resource
265    ///
266    /// Returns the response body if the API returns one. Some APIs (especially
267    /// preview endpoints like Knowledge Sources) return 204 No Content on
268    /// successful update, which yields `Ok(None)`.
269    #[instrument(skip(self, definition))]
270    pub async fn create_or_update(
271        &self,
272        kind: ResourceKind,
273        name: &str,
274        definition: &Value,
275    ) -> Result<Option<Value>, ClientError> {
276        let url = self.resource_url(kind, name);
277        self.request_with_retry(Method::PUT, &url, Some(definition))
278            .await
279    }
280
281    /// Delete a resource
282    #[instrument(skip(self))]
283    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
284        let url = self.resource_url(kind, name);
285        self.request_with_retry(Method::DELETE, &url, None).await?;
286        Ok(())
287    }
288
289    /// Check if a resource exists
290    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
291        match self.get(kind, name).await {
292            Ok(_) => Ok(true),
293            Err(ClientError::NotFound { .. }) => Ok(false),
294            Err(e) => Err(e),
295        }
296    }
297
298    /// Get the authentication method being used
299    pub fn auth_method(&self) -> &'static str {
300        self.auth.method_name()
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Preview API version every test client is configured with.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider stub that always hands back the same token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client for `base_url` backed by [`FakeAuth`].
    fn client_for(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            String::from(base_url),
            String::from(PREVIEW_VERSION),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Default test client pointing at the `test-svc` service.
    fn make_client() -> AzureSearchClient {
        client_for("https://test-svc.search.windows.net")
    }

    /// A 503-style error, used to exercise the exponential backoff branch.
    fn unavailable(message: &str) -> ClientError {
        ClientError::ServiceUnavailable(message.to_string())
    }

    #[test]
    fn test_collection_url_uses_preview_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::Index),
            "https://test-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_base_uses_agentic_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::KnowledgeBase),
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-08-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_agentic_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::KnowledgeSource),
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-08-01-preview"
        );
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        assert_eq!(
            make_client().resource_url(ResourceKind::Index, "my-index"),
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_knowledge_base_uses_agentic_version() {
        assert_eq!(
            make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb"),
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-08-01-preview"
        );
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // new_for_server itself calls get_auth_provider, which is awkward to drive
        // from a unit test, so the same URL shape is checked through with_auth.
        let other = client_for("https://other-svc.search.windows.net");
        assert_eq!(
            other.collection_url(ResourceKind::Index),
            "https://other-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_kinds_use_expected_api_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            let url = client.collection_url(*kind);
            // Pick the version each kind must carry, plus the label used in the
            // failure message.
            let (expected, label) = match kind {
                ResourceKind::KnowledgeBase | ResourceKind::KnowledgeSource => {
                    ("2025-08-01-preview", "agentic retrieval")
                }
                _ => ("2025-11-01-preview", "preview"),
            };
            assert!(
                url.contains(expected),
                "{:?} should use {} API version, got: {}",
                kind,
                label,
                url
            );
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        // The attempt number must not influence a rate-limited delay.
        for attempt in 0..3 {
            assert_eq!(retry_delay(&err, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = unavailable("temporarily unavailable");
        let expected: Vec<Duration> = [1, 2, 4].into_iter().map(Duration::from_secs).collect();
        let mut delays = Vec::new();
        for attempt in 0..MAX_RETRIES {
            delays.push(retry_delay(&err, attempt));
        }
        assert_eq!(delays, expected);
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // retry_delay has no opinion on retryability (that is the caller's call), so
        // it must still hand back a delay, without panicking, for a plain 400.
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }
}