Skip to main content

hoist_client/
client.rs

1//! Azure Search REST API client
2
3use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15/// Maximum number of retry attempts for retryable errors
16const MAX_RETRIES: u32 = 3;
17
18/// Initial backoff delay in seconds
19const INITIAL_BACKOFF_SECS: u64 = 1;
20
21/// Calculate the backoff duration for a given retry attempt.
22///
23/// For `RateLimited` errors with a `retry_after` value, that value is used directly.
24/// For other retryable errors, exponential backoff is applied: 1s, 2s, 4s, etc.
25fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26    match error {
27        ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28        _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29    }
30}
31
32/// Azure Search API client
33pub struct AzureSearchClient {
34    http: Client,
35    auth: Box<dyn AuthProvider>,
36    base_url: String,
37    preview_api_version: String,
38}
39
40impl AzureSearchClient {
41    /// Create a new client from configuration
42    pub fn new(config: &Config) -> Result<Self, ClientError> {
43        let auth = get_auth_provider()?;
44        let http = Client::builder()
45            .timeout(std::time::Duration::from_secs(30))
46            .build()?;
47
48        Ok(Self {
49            http,
50            auth,
51            base_url: config.service_url(),
52            preview_api_version: config.api_version_for(true).to_string(),
53        })
54    }
55
56    /// Create a client pointing to a different server, using the same auth and API versions
57    pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
58        let auth = get_auth_provider()?;
59        let http = Client::builder()
60            .timeout(std::time::Duration::from_secs(30))
61            .build()?;
62
63        Ok(Self {
64            http,
65            auth,
66            base_url: format!("https://{}.search.windows.net", server_name),
67            preview_api_version: config.api_version_for(true).to_string(),
68        })
69    }
70
71    /// Create with a custom auth provider (for testing)
72    pub fn with_auth(
73        base_url: String,
74        preview_api_version: String,
75        auth: Box<dyn AuthProvider>,
76    ) -> Result<Self, ClientError> {
77        let http = Client::builder()
78            .timeout(std::time::Duration::from_secs(30))
79            .build()?;
80
81        Ok(Self {
82            http,
83            auth,
84            base_url,
85            preview_api_version,
86        })
87    }
88
89    /// Get the API version to use for a resource kind.
90    /// Always uses the preview API version — it is a superset of the stable API
91    /// and avoids failures when stable resources contain preview-only features
92    /// (e.g. a skillset with ChatCompletionSkill).
93    fn api_version_for(&self, _kind: ResourceKind) -> &str {
94        &self.preview_api_version
95    }
96
97    /// Build URL for a resource collection
98    fn collection_url(&self, kind: ResourceKind) -> String {
99        format!(
100            "{}/{}?api-version={}",
101            self.base_url,
102            kind.api_path(),
103            self.api_version_for(kind)
104        )
105    }
106
107    /// Build URL for a specific resource
108    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
109        format!(
110            "{}/{}/{}?api-version={}",
111            self.base_url,
112            kind.api_path(),
113            name,
114            self.api_version_for(kind)
115        )
116    }
117
118    /// Execute an HTTP request
119    async fn request(
120        &self,
121        method: Method,
122        url: &str,
123        body: Option<&Value>,
124    ) -> Result<Option<Value>, ClientError> {
125        let token = self.auth.get_token()?;
126
127        let mut request = self
128            .http
129            .request(method.clone(), url)
130            .header("Authorization", format!("Bearer {}", token))
131            .header("Content-Type", "application/json");
132
133        if let Some(json) = body {
134            request = request.json(json);
135        }
136
137        debug!("Request: {} {}", method, url);
138        let response = request.send().await?;
139        let status = response.status();
140
141        if status == StatusCode::NO_CONTENT {
142            return Ok(None);
143        }
144
145        let body = response.text().await?;
146
147        if status.is_success() {
148            if body.is_empty() {
149                Ok(None)
150            } else {
151                let value: Value = serde_json::from_str(&body)?;
152                Ok(Some(value))
153            }
154        } else {
155            match status {
156                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
157                    kind: "resource".to_string(),
158                    name: url.to_string(),
159                }),
160                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
161                    kind: "resource".to_string(),
162                    name: url.to_string(),
163                }),
164                StatusCode::TOO_MANY_REQUESTS => {
165                    let retry_after = 60; // Default retry time
166                    Err(ClientError::RateLimited { retry_after })
167                }
168                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
169                _ => Err(ClientError::from_response_with_url(
170                    status.as_u16(),
171                    &body,
172                    Some(url),
173                )),
174            }
175        }
176    }
177
178    /// Execute an HTTP request with retry logic for transient errors.
179    ///
180    /// Retries up to [`MAX_RETRIES`] times for retryable errors (429 and 503).
181    /// Uses exponential backoff (1s, 2s, 4s) for 503 errors and respects the
182    /// `retry_after` value for 429 rate-limiting errors.
183    async fn request_with_retry(
184        &self,
185        method: Method,
186        url: &str,
187        body: Option<&Value>,
188    ) -> Result<Option<Value>, ClientError> {
189        let mut attempt = 0u32;
190        loop {
191            match self.request(method.clone(), url, body).await {
192                Ok(value) => return Ok(value),
193                Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
194                    let delay = retry_delay(&err, attempt);
195                    warn!(
196                        "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
197                        method,
198                        url,
199                        attempt + 1,
200                        MAX_RETRIES + 1,
201                        err,
202                        delay,
203                    );
204                    tokio::time::sleep(delay).await;
205                    attempt += 1;
206                }
207                Err(err) => return Err(err),
208            }
209        }
210    }
211
212    /// List all resources of a given kind
213    #[instrument(skip(self))]
214    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
215        let url = self.collection_url(kind);
216        let response = self.request_with_retry(Method::GET, &url, None).await?;
217
218        match response {
219            Some(value) => {
220                // Azure returns { "value": [...] }
221                let items = value
222                    .get("value")
223                    .and_then(|v| v.as_array())
224                    .cloned()
225                    .unwrap_or_default();
226                Ok(items)
227            }
228            None => Ok(Vec::new()),
229        }
230    }
231
232    /// Get a specific resource
233    #[instrument(skip(self))]
234    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
235        let url = self.resource_url(kind, name);
236        let response = self.request_with_retry(Method::GET, &url, None).await?;
237
238        response.ok_or_else(|| ClientError::NotFound {
239            kind: kind.display_name().to_string(),
240            name: name.to_string(),
241        })
242    }
243
244    /// Create or update a resource
245    ///
246    /// Returns the response body if the API returns one. Some APIs (especially
247    /// preview endpoints like Knowledge Sources) return 204 No Content on
248    /// successful update, which yields `Ok(None)`.
249    #[instrument(skip(self, definition))]
250    pub async fn create_or_update(
251        &self,
252        kind: ResourceKind,
253        name: &str,
254        definition: &Value,
255    ) -> Result<Option<Value>, ClientError> {
256        let url = self.resource_url(kind, name);
257        self.request_with_retry(Method::PUT, &url, Some(definition))
258            .await
259    }
260
261    /// Delete a resource
262    #[instrument(skip(self))]
263    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
264        let url = self.resource_url(kind, name);
265        self.request_with_retry(Method::DELETE, &url, None).await?;
266        Ok(())
267    }
268
269    /// Check if a resource exists
270    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
271        match self.get(kind, name).await {
272            Ok(_) => Ok(true),
273            Err(ClientError::NotFound { .. }) => Ok(false),
274            Err(e) => Err(e),
275        }
276    }
277
278    /// Get the authentication method being used
279    pub fn auth_method(&self) -> &'static str {
280        self.auth.method_name()
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Root URL of the service most tests point at.
    const TEST_BASE: &str = "https://test-svc.search.windows.net";

    /// API version every generated URL is expected to carry.
    const PREVIEW: &str = "2025-11-01-preview";

    /// Auth stub that always succeeds with a fixed token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client rooted at `base` using the preview API version and fake auth.
    fn client_at(base: &str) -> AzureSearchClient {
        let built = AzureSearchClient::with_auth(
            base.to_owned(),
            PREVIEW.to_owned(),
            Box::new(FakeAuth),
        );
        built.unwrap()
    }

    /// Client pointed at the default test service.
    fn make_client() -> AzureSearchClient {
        client_at(TEST_BASE)
    }

    /// Expected URL for `path` under `base`, tagged with the preview API version.
    fn expected_url(base: &str, path: &str) -> String {
        format!("{}/{}?api-version={}", base, path, PREVIEW)
    }

    /// A 503 error: the canonical case that gets exponential backoff.
    fn unavailable(msg: &str) -> ClientError {
        ClientError::ServiceUnavailable(msg.to_string())
    }

    #[test]
    fn test_collection_url_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::Index);
        assert_eq!(actual, expected_url(TEST_BASE, "indexes"));
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(actual, expected_url(TEST_BASE, "knowledgebases"));
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(actual, expected_url(TEST_BASE, "knowledgesources"));
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        let actual = make_client().resource_url(ResourceKind::Index, "my-index");
        assert_eq!(actual, expected_url(TEST_BASE, "indexes/my-index"));
    }

    #[test]
    fn test_resource_url_preview() {
        let actual = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(actual, expected_url(TEST_BASE, "knowledgebases/my-kb"));
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // new_for_server needs a real auth provider, so exercise the same
        // "other server" base URL shape through with_auth instead.
        let other = "https://other-svc.search.windows.net";
        let actual = client_at(other).collection_url(ResourceKind::Index);
        assert_eq!(actual, expected_url(other, "indexes"));
    }

    #[test]
    fn test_all_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(PREVIEW),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        // The server-provided delay wins no matter which attempt we are on.
        for attempt in 0..3 {
            assert_eq!(retry_delay(&err, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = unavailable("temporarily unavailable");
        let expected: Vec<Duration> = [1u64, 2, 4]
            .iter()
            .map(|&secs| Duration::from_secs(secs))
            .collect();
        let actual: Vec<Duration> = (0..MAX_RETRIES).map(|i| retry_delay(&err, i)).collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // retry_delay always yields a delay; deciding whether to retry is the
        // caller's job. This just checks a non-retryable error doesn't panic.
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }
}