Skip to main content

hoist_client/
client.rs

1//! Azure Search REST API client
2
3use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::config::SearchServiceConfig;
10use hoist_core::resources::ResourceKind;
11
12use crate::auth::{AuthProvider, get_auth_provider};
13use crate::error::ClientError;
14
15/// Maximum number of retry attempts for retryable errors
16const MAX_RETRIES: u32 = 3;
17
18/// Initial backoff delay in seconds
19const INITIAL_BACKOFF_SECS: u64 = 1;
20
21/// Calculate the backoff duration for a given retry attempt.
22///
23/// For `RateLimited` errors with a `retry_after` value, that value is used directly.
24/// For other retryable errors, exponential backoff is applied: 1s, 2s, 4s, etc.
25fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26    match error {
27        ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28        _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29    }
30}
31
32/// Azure Search API client
33pub struct AzureSearchClient {
34    http: Client,
35    auth: Box<dyn AuthProvider>,
36    base_url: String,
37    preview_api_version: String,
38}
39
40impl AzureSearchClient {
41    /// Create client from a specific search service config
42    pub fn from_service_config(service: &SearchServiceConfig) -> Result<Self, ClientError> {
43        let auth = get_auth_provider()?;
44        let http = Client::builder().timeout(Duration::from_secs(30)).build()?;
45
46        Ok(Self {
47            http,
48            auth,
49            base_url: service.service_url(),
50            preview_api_version: service.preview_api_version.clone(),
51        })
52    }
53
54    /// Create with a custom auth provider (for testing)
55    pub fn with_auth(
56        base_url: String,
57        preview_api_version: String,
58        auth: Box<dyn AuthProvider>,
59    ) -> Result<Self, ClientError> {
60        let http = Client::builder()
61            .timeout(std::time::Duration::from_secs(30))
62            .build()?;
63
64        Ok(Self {
65            http,
66            auth,
67            base_url,
68            preview_api_version,
69        })
70    }
71
72    /// Get the API version to use for a resource kind.
73    /// Always uses the preview API version — it is a superset of the stable API
74    /// and avoids failures when stable resources contain preview-only features
75    /// (e.g. a skillset with ChatCompletionSkill).
76    fn api_version_for(&self, _kind: ResourceKind) -> &str {
77        &self.preview_api_version
78    }
79
80    /// Build URL for a resource collection
81    fn collection_url(&self, kind: ResourceKind) -> String {
82        format!(
83            "{}/{}?api-version={}",
84            self.base_url,
85            kind.api_path(),
86            self.api_version_for(kind)
87        )
88    }
89
90    /// Build URL for a specific resource
91    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
92        format!(
93            "{}/{}/{}?api-version={}",
94            self.base_url,
95            kind.api_path(),
96            urlencoding::encode(name),
97            self.api_version_for(kind)
98        )
99    }
100
101    /// Execute an HTTP request
102    async fn request(
103        &self,
104        method: Method,
105        url: &str,
106        body: Option<&Value>,
107    ) -> Result<Option<Value>, ClientError> {
108        let token = self.auth.get_token()?;
109
110        let mut request = self
111            .http
112            .request(method.clone(), url)
113            .header("Authorization", format!("Bearer {}", token))
114            .header("Content-Type", "application/json");
115
116        if let Some(json) = body {
117            request = request.json(json);
118        }
119
120        debug!("Request: {} {}", method, url);
121        let response = request.send().await?;
122        let status = response.status();
123
124        if status == StatusCode::NO_CONTENT {
125            return Ok(None);
126        }
127
128        let body = response.text().await?;
129
130        if status.is_success() {
131            if body.is_empty() {
132                Ok(None)
133            } else {
134                let value: Value = serde_json::from_str(&body)?;
135                Ok(Some(value))
136            }
137        } else {
138            match status {
139                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
140                    kind: "resource".to_string(),
141                    name: url.to_string(),
142                }),
143                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
144                    kind: "resource".to_string(),
145                    name: url.to_string(),
146                }),
147                StatusCode::TOO_MANY_REQUESTS => {
148                    let retry_after = 60; // Default retry time
149                    Err(ClientError::RateLimited { retry_after })
150                }
151                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
152                _ => Err(ClientError::from_response_with_url(
153                    status.as_u16(),
154                    &body,
155                    Some(url),
156                )),
157            }
158        }
159    }
160
161    /// Execute an HTTP request with retry logic for transient errors.
162    ///
163    /// Retries up to [`MAX_RETRIES`] times for retryable errors (429 and 503).
164    /// Uses exponential backoff (1s, 2s, 4s) for 503 errors and respects the
165    /// `retry_after` value for 429 rate-limiting errors.
166    async fn request_with_retry(
167        &self,
168        method: Method,
169        url: &str,
170        body: Option<&Value>,
171    ) -> Result<Option<Value>, ClientError> {
172        let mut attempt = 0u32;
173        loop {
174            match self.request(method.clone(), url, body).await {
175                Ok(value) => return Ok(value),
176                Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
177                    let delay = retry_delay(&err, attempt);
178                    warn!(
179                        "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
180                        method,
181                        url,
182                        attempt + 1,
183                        MAX_RETRIES + 1,
184                        err,
185                        delay,
186                    );
187                    tokio::time::sleep(delay).await;
188                    attempt += 1;
189                }
190                Err(err) => return Err(err),
191            }
192        }
193    }
194
195    /// List all resources of a given kind
196    #[instrument(skip(self))]
197    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
198        let url = self.collection_url(kind);
199        let response = self.request_with_retry(Method::GET, &url, None).await?;
200
201        match response {
202            Some(value) => {
203                // Azure returns { "value": [...] }
204                let items = value
205                    .get("value")
206                    .and_then(|v| v.as_array())
207                    .cloned()
208                    .unwrap_or_default();
209                Ok(items)
210            }
211            None => Ok(Vec::new()),
212        }
213    }
214
215    /// Get a specific resource
216    #[instrument(skip(self))]
217    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
218        let url = self.resource_url(kind, name);
219        let response = self.request_with_retry(Method::GET, &url, None).await?;
220
221        response.ok_or_else(|| ClientError::NotFound {
222            kind: kind.display_name().to_string(),
223            name: name.to_string(),
224        })
225    }
226
227    /// Create or update a resource
228    ///
229    /// Returns the response body if the API returns one. Some APIs (especially
230    /// preview endpoints like Knowledge Sources) return 204 No Content on
231    /// successful update, which yields `Ok(None)`.
232    #[instrument(skip(self, definition))]
233    pub async fn create_or_update(
234        &self,
235        kind: ResourceKind,
236        name: &str,
237        definition: &Value,
238    ) -> Result<Option<Value>, ClientError> {
239        let url = self.resource_url(kind, name);
240        self.request_with_retry(Method::PUT, &url, Some(definition))
241            .await
242    }
243
244    /// Delete a resource
245    #[instrument(skip(self))]
246    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
247        let url = self.resource_url(kind, name);
248        self.request_with_retry(Method::DELETE, &url, None).await?;
249        Ok(())
250    }
251
252    /// Check if a resource exists
253    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
254        match self.get(kind, name).await {
255            Ok(_) => Ok(true),
256            Err(ClientError::NotFound { .. }) => Ok(false),
257            Err(e) => Err(e),
258        }
259    }
260
261    /// Get the authentication method being used
262    pub fn auth_method(&self) -> &'static str {
263        self.auth.method_name()
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Auth provider stub: always succeeds with a fixed token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client for `base_url` using the fake auth provider and the
    /// preview API version the URL assertions below expect.
    fn client_at(base_url: &str) -> AzureSearchClient {
        let built = AzureSearchClient::with_auth(
            String::from(base_url),
            String::from("2025-11-01-preview"),
            Box::new(FakeAuth),
        );
        built.unwrap()
    }

    /// Client pointed at the default test service.
    fn make_client() -> AzureSearchClient {
        client_at("https://test-svc.search.windows.net")
    }

    /// A retryable 503-style error carrying `message`.
    fn unavailable(message: &str) -> ClientError {
        ClientError::ServiceUnavailable(message.to_string())
    }

    // ---- URL construction ------------------------------------------------

    #[test]
    fn test_collection_url_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::Index);
        let expected = "https://test-svc.search.windows.net/indexes?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_collection_url_knowledge_base_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeBase);
        let expected =
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeSource);
        let expected =
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        let actual = make_client().resource_url(ResourceKind::Index, "my-index");
        let expected =
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_resource_url_knowledge_base_uses_preview_version() {
        let actual = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        let expected = "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // `from_service_config` calls `get_auth_provider`, so it is not
        // exercised here; the base URL handling is checked through
        // `with_auth` against a second service name instead.
        let other = client_at("https://other-svc.search.windows.net");
        let actual = other.collection_url(ResourceKind::Index);
        let expected =
            "https://other-svc.search.windows.net/indexes?api-version=2025-11-01-preview";
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_all_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            let url = client.collection_url(*kind);
            let uses_preview = url.contains("2025-11-01-preview");
            assert!(
                uses_preview,
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    // ---- retry_delay -----------------------------------------------------

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let limited = ClientError::RateLimited { retry_after: 30 };
        // The attempt number must not influence a rate-limit delay.
        for attempt in 0..3u32 {
            assert_eq!(retry_delay(&limited, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let limited = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&limited, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = unavailable("temporarily unavailable");
        let observed: Vec<Duration> = (0..MAX_RETRIES)
            .map(|attempt| retry_delay(&err, attempt))
            .collect();
        let expected: Vec<Duration> = [1u64, 2, 4]
            .iter()
            .copied()
            .map(Duration::from_secs)
            .collect();
        assert_eq!(observed, expected);
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // `retry_delay` only computes a duration; deciding whether to retry
        // is the caller's job. A non-retryable error must therefore still
        // produce the plain backoff value rather than panic.
        let bad_request = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&bad_request, 0), Duration::from_secs(1));
    }
}