Skip to main content

hoist_client/
client.rs

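//! Azure Search REST API client.
//!
//! Provides [`AzureSearchClient`], which lists, fetches, creates/updates and
//! deletes service resources over HTTP using bearer-token authentication, and
//! retries transient failures.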
2
3use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15/// Maximum number of retry attempts for retryable errors
16const MAX_RETRIES: u32 = 3;
17
18/// Initial backoff delay in seconds
19const INITIAL_BACKOFF_SECS: u64 = 1;
20
21/// Calculate the backoff duration for a given retry attempt.
22///
23/// For `RateLimited` errors with a `retry_after` value, that value is used directly.
24/// For other retryable errors, exponential backoff is applied: 1s, 2s, 4s, etc.
25fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26    match error {
27        ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28        _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29    }
30}
31
32/// Azure Search API client
33pub struct AzureSearchClient {
34    http: Client,
35    auth: Box<dyn AuthProvider>,
36    base_url: String,
37    api_version: String,
38    preview_api_version: String,
39}
40
41impl AzureSearchClient {
42    /// Create a new client from configuration
43    pub fn new(config: &Config) -> Result<Self, ClientError> {
44        let auth = get_auth_provider()?;
45        let http = Client::builder()
46            .timeout(std::time::Duration::from_secs(30))
47            .build()?;
48
49        Ok(Self {
50            http,
51            auth,
52            base_url: config.service_url(),
53            api_version: config.api_version_for(false).to_string(),
54            preview_api_version: config.api_version_for(true).to_string(),
55        })
56    }
57
58    /// Create a client pointing to a different server, using the same auth and API versions
59    pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
60        let auth = get_auth_provider()?;
61        let http = Client::builder()
62            .timeout(std::time::Duration::from_secs(30))
63            .build()?;
64
65        Ok(Self {
66            http,
67            auth,
68            base_url: format!("https://{}.search.windows.net", server_name),
69            api_version: config.api_version_for(false).to_string(),
70            preview_api_version: config.api_version_for(true).to_string(),
71        })
72    }
73
74    /// Create with a custom auth provider (for testing)
75    pub fn with_auth(
76        base_url: String,
77        api_version: String,
78        preview_api_version: String,
79        auth: Box<dyn AuthProvider>,
80    ) -> Result<Self, ClientError> {
81        let http = Client::builder()
82            .timeout(std::time::Duration::from_secs(30))
83            .build()?;
84
85        Ok(Self {
86            http,
87            auth,
88            base_url,
89            api_version,
90            preview_api_version,
91        })
92    }
93
94    /// Get the API version to use for a resource kind
95    fn api_version_for(&self, kind: ResourceKind) -> &str {
96        if kind.is_preview() {
97            &self.preview_api_version
98        } else {
99            &self.api_version
100        }
101    }
102
103    /// Build URL for a resource collection
104    fn collection_url(&self, kind: ResourceKind) -> String {
105        format!(
106            "{}/{}?api-version={}",
107            self.base_url,
108            kind.api_path(),
109            self.api_version_for(kind)
110        )
111    }
112
113    /// Build URL for a specific resource
114    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
115        format!(
116            "{}/{}/{}?api-version={}",
117            self.base_url,
118            kind.api_path(),
119            name,
120            self.api_version_for(kind)
121        )
122    }
123
124    /// Execute an HTTP request
125    async fn request(
126        &self,
127        method: Method,
128        url: &str,
129        body: Option<&Value>,
130    ) -> Result<Option<Value>, ClientError> {
131        let token = self.auth.get_token()?;
132
133        let mut request = self
134            .http
135            .request(method.clone(), url)
136            .header("Authorization", format!("Bearer {}", token))
137            .header("Content-Type", "application/json");
138
139        if let Some(json) = body {
140            request = request.json(json);
141        }
142
143        debug!("Request: {} {}", method, url);
144        let response = request.send().await?;
145        let status = response.status();
146
147        if status == StatusCode::NO_CONTENT {
148            return Ok(None);
149        }
150
151        let body = response.text().await?;
152
153        if status.is_success() {
154            if body.is_empty() {
155                Ok(None)
156            } else {
157                let value: Value = serde_json::from_str(&body)?;
158                Ok(Some(value))
159            }
160        } else {
161            match status {
162                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
163                    kind: "resource".to_string(),
164                    name: url.to_string(),
165                }),
166                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
167                    kind: "resource".to_string(),
168                    name: url.to_string(),
169                }),
170                StatusCode::TOO_MANY_REQUESTS => {
171                    let retry_after = 60; // Default retry time
172                    Err(ClientError::RateLimited { retry_after })
173                }
174                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
175                _ => Err(ClientError::from_response_with_url(
176                    status.as_u16(),
177                    &body,
178                    Some(url),
179                )),
180            }
181        }
182    }
183
184    /// Execute an HTTP request with retry logic for transient errors.
185    ///
186    /// Retries up to [`MAX_RETRIES`] times for retryable errors (429 and 503).
187    /// Uses exponential backoff (1s, 2s, 4s) for 503 errors and respects the
188    /// `retry_after` value for 429 rate-limiting errors.
189    async fn request_with_retry(
190        &self,
191        method: Method,
192        url: &str,
193        body: Option<&Value>,
194    ) -> Result<Option<Value>, ClientError> {
195        let mut attempt = 0u32;
196        loop {
197            match self.request(method.clone(), url, body).await {
198                Ok(value) => return Ok(value),
199                Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
200                    let delay = retry_delay(&err, attempt);
201                    warn!(
202                        "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
203                        method,
204                        url,
205                        attempt + 1,
206                        MAX_RETRIES + 1,
207                        err,
208                        delay,
209                    );
210                    tokio::time::sleep(delay).await;
211                    attempt += 1;
212                }
213                Err(err) => return Err(err),
214            }
215        }
216    }
217
218    /// List all resources of a given kind
219    #[instrument(skip(self))]
220    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
221        let url = self.collection_url(kind);
222        let response = self.request_with_retry(Method::GET, &url, None).await?;
223
224        match response {
225            Some(value) => {
226                // Azure returns { "value": [...] }
227                let items = value
228                    .get("value")
229                    .and_then(|v| v.as_array())
230                    .cloned()
231                    .unwrap_or_default();
232                Ok(items)
233            }
234            None => Ok(Vec::new()),
235        }
236    }
237
238    /// Get a specific resource
239    #[instrument(skip(self))]
240    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
241        let url = self.resource_url(kind, name);
242        let response = self.request_with_retry(Method::GET, &url, None).await?;
243
244        response.ok_or_else(|| ClientError::NotFound {
245            kind: kind.display_name().to_string(),
246            name: name.to_string(),
247        })
248    }
249
250    /// Create or update a resource
251    ///
252    /// Returns the response body if the API returns one. Some APIs (especially
253    /// preview endpoints like Knowledge Sources) return 204 No Content on
254    /// successful update, which yields `Ok(None)`.
255    #[instrument(skip(self, definition))]
256    pub async fn create_or_update(
257        &self,
258        kind: ResourceKind,
259        name: &str,
260        definition: &Value,
261    ) -> Result<Option<Value>, ClientError> {
262        let url = self.resource_url(kind, name);
263        self.request_with_retry(Method::PUT, &url, Some(definition))
264            .await
265    }
266
267    /// Delete a resource
268    #[instrument(skip(self))]
269    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
270        let url = self.resource_url(kind, name);
271        self.request_with_retry(Method::DELETE, &url, None).await?;
272        Ok(())
273    }
274
275    /// Check if a resource exists
276    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
277        match self.get(kind, name).await {
278            Ok(_) => Ok(true),
279            Err(ClientError::NotFound { .. }) => Ok(false),
280            Err(e) => Err(e),
281        }
282    }
283
284    /// Get the authentication method being used
285    pub fn auth_method(&self) -> &'static str {
286        self.auth.method_name()
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Stable API version given to every test client.
    const STABLE_VERSION: &str = "2024-07-01";
    /// Preview API version given to every test client.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider stub that always hands out the same token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client for `base_url` using the fake auth provider.
    fn client_for(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_owned(),
            STABLE_VERSION.to_owned(),
            PREVIEW_VERSION.to_owned(),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Client aimed at the default `test-svc` service.
    fn make_client() -> AzureSearchClient {
        client_for("https://test-svc.search.windows.net")
    }

    /// A non-rate-limited error, used to exercise exponential backoff.
    fn unavailable(message: &str) -> ClientError {
        ClientError::ServiceUnavailable(message.to_owned())
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let actual = make_client().collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_stable() {
        let actual = make_client().resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_resource_url_preview() {
        let actual = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(STABLE_VERSION),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // new_for_server calls get_auth_provider, so it is not exercised
        // directly; the resulting URL shape is checked through with_auth.
        let actual = client_for("https://other-svc.search.windows.net")
            .collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://other-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            if !kind.is_preview() {
                continue;
            }
            let url = client.collection_url(*kind);
            assert!(
                url.contains(PREVIEW_VERSION),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        // The server-provided delay wins no matter which attempt this is.
        for attempt in 0..3 {
            assert_eq!(retry_delay(&err, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = unavailable("temporarily unavailable");
        let actual: Vec<Duration> = (0..MAX_RETRIES).map(|i| retry_delay(&err, i)).collect();
        let expected: Vec<Duration> = [1u64, 2, 4]
            .iter()
            .map(|&secs| Duration::from_secs(secs))
            .collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // retry_delay always produces a delay; deciding whether to retry is the
        // caller's job. This guards against panics on non-retryable errors.
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }
}