Skip to main content

hoist_client/
client.rs

1//! Azure Search REST API client
2
3use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
/// How many additional times a retryable request is re-sent after its first
/// attempt fails (so at most `MAX_RETRIES + 1` attempts in total).
const MAX_RETRIES: u32 = 3;

/// Base delay, in seconds, for exponential backoff: retry number `n`
/// (0-based) waits `INITIAL_BACKOFF_SECS * 2^n` seconds.
const INITIAL_BACKOFF_SECS: u64 = 1;
20
21/// Calculate the backoff duration for a given retry attempt.
22///
23/// For `RateLimited` errors with a `retry_after` value, that value is used directly.
24/// For other retryable errors, exponential backoff is applied: 1s, 2s, 4s, etc.
25fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26    match error {
27        ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28        _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29    }
30}
31
32/// Azure Search API client
33pub struct AzureSearchClient {
34    http: Client,
35    auth: Box<dyn AuthProvider>,
36    base_url: String,
37    api_version: String,
38    preview_api_version: String,
39}
40
41impl AzureSearchClient {
42    /// Create a new client from configuration
43    pub fn new(config: &Config) -> Result<Self, ClientError> {
44        let auth = get_auth_provider()?;
45        let http = Client::builder()
46            .timeout(std::time::Duration::from_secs(30))
47            .build()?;
48
49        Ok(Self {
50            http,
51            auth,
52            base_url: config.service_url(),
53            api_version: config.api_version_for(false).to_string(),
54            preview_api_version: config.api_version_for(true).to_string(),
55        })
56    }
57
58    /// Create a client pointing to a different server, using the same auth and API versions
59    pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
60        let auth = get_auth_provider()?;
61        let http = Client::builder()
62            .timeout(std::time::Duration::from_secs(30))
63            .build()?;
64
65        Ok(Self {
66            http,
67            auth,
68            base_url: format!("https://{}.search.windows.net", server_name),
69            api_version: config.api_version_for(false).to_string(),
70            preview_api_version: config.api_version_for(true).to_string(),
71        })
72    }
73
74    /// Create with a custom auth provider (for testing)
75    pub fn with_auth(
76        base_url: String,
77        api_version: String,
78        preview_api_version: String,
79        auth: Box<dyn AuthProvider>,
80    ) -> Result<Self, ClientError> {
81        let http = Client::builder()
82            .timeout(std::time::Duration::from_secs(30))
83            .build()?;
84
85        Ok(Self {
86            http,
87            auth,
88            base_url,
89            api_version,
90            preview_api_version,
91        })
92    }
93
94    /// Get the API version to use for a resource kind
95    fn api_version_for(&self, kind: ResourceKind) -> &str {
96        if kind.is_preview() {
97            &self.preview_api_version
98        } else {
99            &self.api_version
100        }
101    }
102
103    /// Build URL for a resource collection
104    fn collection_url(&self, kind: ResourceKind) -> String {
105        format!(
106            "{}/{}?api-version={}",
107            self.base_url,
108            kind.api_path(),
109            self.api_version_for(kind)
110        )
111    }
112
113    /// Build URL for a specific resource
114    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
115        format!(
116            "{}/{}/{}?api-version={}",
117            self.base_url,
118            kind.api_path(),
119            name,
120            self.api_version_for(kind)
121        )
122    }
123
124    /// Execute an HTTP request
125    async fn request(
126        &self,
127        method: Method,
128        url: &str,
129        body: Option<&Value>,
130    ) -> Result<Option<Value>, ClientError> {
131        let token = self.auth.get_token()?;
132
133        let mut request = self
134            .http
135            .request(method.clone(), url)
136            .header("Authorization", format!("Bearer {}", token))
137            .header("Content-Type", "application/json");
138
139        if let Some(json) = body {
140            request = request.json(json);
141        }
142
143        debug!("Request: {} {}", method, url);
144        let response = request.send().await?;
145        let status = response.status();
146
147        if status == StatusCode::NO_CONTENT {
148            return Ok(None);
149        }
150
151        let body = response.text().await?;
152
153        if status.is_success() {
154            if body.is_empty() {
155                Ok(None)
156            } else {
157                let value: Value = serde_json::from_str(&body)?;
158                Ok(Some(value))
159            }
160        } else {
161            match status {
162                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
163                    kind: "resource".to_string(),
164                    name: url.to_string(),
165                }),
166                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
167                    kind: "resource".to_string(),
168                    name: url.to_string(),
169                }),
170                StatusCode::TOO_MANY_REQUESTS => {
171                    let retry_after = 60; // Default retry time
172                    Err(ClientError::RateLimited { retry_after })
173                }
174                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
175                _ => Err(ClientError::from_response(status.as_u16(), &body)),
176            }
177        }
178    }
179
180    /// Execute an HTTP request with retry logic for transient errors.
181    ///
182    /// Retries up to [`MAX_RETRIES`] times for retryable errors (429 and 503).
183    /// Uses exponential backoff (1s, 2s, 4s) for 503 errors and respects the
184    /// `retry_after` value for 429 rate-limiting errors.
185    async fn request_with_retry(
186        &self,
187        method: Method,
188        url: &str,
189        body: Option<&Value>,
190    ) -> Result<Option<Value>, ClientError> {
191        let mut attempt = 0u32;
192        loop {
193            match self.request(method.clone(), url, body).await {
194                Ok(value) => return Ok(value),
195                Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
196                    let delay = retry_delay(&err, attempt);
197                    warn!(
198                        "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
199                        method,
200                        url,
201                        attempt + 1,
202                        MAX_RETRIES + 1,
203                        err,
204                        delay,
205                    );
206                    tokio::time::sleep(delay).await;
207                    attempt += 1;
208                }
209                Err(err) => return Err(err),
210            }
211        }
212    }
213
214    /// List all resources of a given kind
215    #[instrument(skip(self))]
216    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
217        let url = self.collection_url(kind);
218        let response = self.request_with_retry(Method::GET, &url, None).await?;
219
220        match response {
221            Some(value) => {
222                // Azure returns { "value": [...] }
223                let items = value
224                    .get("value")
225                    .and_then(|v| v.as_array())
226                    .cloned()
227                    .unwrap_or_default();
228                Ok(items)
229            }
230            None => Ok(Vec::new()),
231        }
232    }
233
234    /// Get a specific resource
235    #[instrument(skip(self))]
236    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
237        let url = self.resource_url(kind, name);
238        let response = self.request_with_retry(Method::GET, &url, None).await?;
239
240        response.ok_or_else(|| ClientError::NotFound {
241            kind: kind.display_name().to_string(),
242            name: name.to_string(),
243        })
244    }
245
246    /// Create or update a resource
247    ///
248    /// Returns the response body if the API returns one. Some APIs (especially
249    /// preview endpoints like Knowledge Sources) return 204 No Content on
250    /// successful update, which yields `Ok(None)`.
251    #[instrument(skip(self, definition))]
252    pub async fn create_or_update(
253        &self,
254        kind: ResourceKind,
255        name: &str,
256        definition: &Value,
257    ) -> Result<Option<Value>, ClientError> {
258        let url = self.resource_url(kind, name);
259        self.request_with_retry(Method::PUT, &url, Some(definition))
260            .await
261    }
262
263    /// Delete a resource
264    #[instrument(skip(self))]
265    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
266        let url = self.resource_url(kind, name);
267        self.request_with_retry(Method::DELETE, &url, None).await?;
268        Ok(())
269    }
270
271    /// Check if a resource exists
272    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
273        match self.get(kind, name).await {
274            Ok(_) => Ok(true),
275            Err(ClientError::NotFound { .. }) => Ok(false),
276            Err(e) => Err(e),
277        }
278    }
279
280    /// Get the authentication method being used
281    pub fn auth_method(&self) -> &'static str {
282        self.auth.method_name()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Stable API version every test client is configured with.
    const STABLE_VERSION: &str = "2024-07-01";
    /// Preview API version every test client is configured with.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider that hands out a constant token, so no real credentials are needed.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client for `base_url` with the fake provider and the test API versions.
    fn client_at(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_owned(),
            STABLE_VERSION.to_owned(),
            PREVIEW_VERSION.to_owned(),
            Box::new(FakeAuth),
        )
        .expect("building a client with fake auth should not fail")
    }

    /// Client pointed at the default `test-svc` service.
    fn make_client() -> AzureSearchClient {
        client_at("https://test-svc.search.windows.net")
    }

    /// A 503-style error: retryable, but not rate-limited.
    fn unavailable(message: &str) -> ClientError {
        ClientError::ServiceUnavailable(message.to_owned())
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let expected = "https://test-svc.search.windows.net/indexes?api-version=2024-07-01";
        assert_eq!(make_client().collection_url(ResourceKind::Index), expected);
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let expected =
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview";
        assert_eq!(make_client().collection_url(ResourceKind::KnowledgeBase), expected);
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let expected =
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview";
        assert_eq!(make_client().collection_url(ResourceKind::KnowledgeSource), expected);
    }

    #[test]
    fn test_resource_url_stable() {
        let expected = "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01";
        assert_eq!(make_client().resource_url(ResourceKind::Index, "my-index"), expected);
    }

    #[test]
    fn test_resource_url_preview() {
        let expected =
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview";
        assert_eq!(make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb"), expected);
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(STABLE_VERSION),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // new_for_server calls get_auth_provider, so it can't be exercised here.
        // Instead, check the URL shape it produces via with_auth.
        let client = client_at("https://other-svc.search.windows.net");
        let expected = "https://other-svc.search.windows.net/indexes?api-version=2024-07-01";
        assert_eq!(client.collection_url(ResourceKind::Index), expected);
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            if !kind.is_preview() {
                continue;
            }
            let url = client.collection_url(*kind);
            assert!(
                url.contains(PREVIEW_VERSION),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        // The server-supplied retry_after wins regardless of the attempt number.
        for attempt in 0..3 {
            assert_eq!(retry_delay(&err, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = unavailable("temporarily unavailable");
        let mut delays: Vec<Duration> = Vec::new();
        for attempt in 0..MAX_RETRIES {
            delays.push(retry_delay(&err, attempt));
        }
        assert_eq!(delays, [1, 2, 4].map(Duration::from_secs));
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // retry_delay computes a delay for any error; the caller decides whether
        // to retry. This only checks that it does not panic on a non-retryable one.
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }
}