1use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15const MAX_RETRIES: u32 = 3;
17
18const INITIAL_BACKOFF_SECS: u64 = 1;
20
21fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26 match error {
27 ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28 _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29 }
30}
31
32pub struct AzureSearchClient {
34 http: Client,
35 auth: Box<dyn AuthProvider>,
36 base_url: String,
37 api_version: String,
38 preview_api_version: String,
39}
40
41impl AzureSearchClient {
42 pub fn new(config: &Config) -> Result<Self, ClientError> {
44 let auth = get_auth_provider()?;
45 let http = Client::builder()
46 .timeout(std::time::Duration::from_secs(30))
47 .build()?;
48
49 Ok(Self {
50 http,
51 auth,
52 base_url: config.service_url(),
53 api_version: config.api_version_for(false).to_string(),
54 preview_api_version: config.api_version_for(true).to_string(),
55 })
56 }
57
58 pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
60 let auth = get_auth_provider()?;
61 let http = Client::builder()
62 .timeout(std::time::Duration::from_secs(30))
63 .build()?;
64
65 Ok(Self {
66 http,
67 auth,
68 base_url: format!("https://{}.search.windows.net", server_name),
69 api_version: config.api_version_for(false).to_string(),
70 preview_api_version: config.api_version_for(true).to_string(),
71 })
72 }
73
74 pub fn with_auth(
76 base_url: String,
77 api_version: String,
78 preview_api_version: String,
79 auth: Box<dyn AuthProvider>,
80 ) -> Result<Self, ClientError> {
81 let http = Client::builder()
82 .timeout(std::time::Duration::from_secs(30))
83 .build()?;
84
85 Ok(Self {
86 http,
87 auth,
88 base_url,
89 api_version,
90 preview_api_version,
91 })
92 }
93
94 fn api_version_for(&self, kind: ResourceKind) -> &str {
96 if kind.is_preview() {
97 &self.preview_api_version
98 } else {
99 &self.api_version
100 }
101 }
102
103 fn collection_url(&self, kind: ResourceKind) -> String {
105 format!(
106 "{}/{}?api-version={}",
107 self.base_url,
108 kind.api_path(),
109 self.api_version_for(kind)
110 )
111 }
112
113 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
115 format!(
116 "{}/{}/{}?api-version={}",
117 self.base_url,
118 kind.api_path(),
119 name,
120 self.api_version_for(kind)
121 )
122 }
123
124 async fn request(
126 &self,
127 method: Method,
128 url: &str,
129 body: Option<&Value>,
130 ) -> Result<Option<Value>, ClientError> {
131 let token = self.auth.get_token()?;
132
133 let mut request = self
134 .http
135 .request(method.clone(), url)
136 .header("Authorization", format!("Bearer {}", token))
137 .header("Content-Type", "application/json");
138
139 if let Some(json) = body {
140 request = request.json(json);
141 }
142
143 debug!("Request: {} {}", method, url);
144 let response = request.send().await?;
145 let status = response.status();
146
147 if status == StatusCode::NO_CONTENT {
148 return Ok(None);
149 }
150
151 let body = response.text().await?;
152
153 if status.is_success() {
154 if body.is_empty() {
155 Ok(None)
156 } else {
157 let value: Value = serde_json::from_str(&body)?;
158 Ok(Some(value))
159 }
160 } else {
161 match status {
162 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
163 kind: "resource".to_string(),
164 name: url.to_string(),
165 }),
166 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
167 kind: "resource".to_string(),
168 name: url.to_string(),
169 }),
170 StatusCode::TOO_MANY_REQUESTS => {
171 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
173 }
174 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
175 _ => Err(ClientError::from_response_with_url(
176 status.as_u16(),
177 &body,
178 Some(url),
179 )),
180 }
181 }
182 }
183
184 async fn request_with_retry(
190 &self,
191 method: Method,
192 url: &str,
193 body: Option<&Value>,
194 ) -> Result<Option<Value>, ClientError> {
195 let mut attempt = 0u32;
196 loop {
197 match self.request(method.clone(), url, body).await {
198 Ok(value) => return Ok(value),
199 Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
200 let delay = retry_delay(&err, attempt);
201 warn!(
202 "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
203 method,
204 url,
205 attempt + 1,
206 MAX_RETRIES + 1,
207 err,
208 delay,
209 );
210 tokio::time::sleep(delay).await;
211 attempt += 1;
212 }
213 Err(err) => return Err(err),
214 }
215 }
216 }
217
218 #[instrument(skip(self))]
220 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
221 let url = self.collection_url(kind);
222 let response = self.request_with_retry(Method::GET, &url, None).await?;
223
224 match response {
225 Some(value) => {
226 let items = value
228 .get("value")
229 .and_then(|v| v.as_array())
230 .cloned()
231 .unwrap_or_default();
232 Ok(items)
233 }
234 None => Ok(Vec::new()),
235 }
236 }
237
238 #[instrument(skip(self))]
240 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
241 let url = self.resource_url(kind, name);
242 let response = self.request_with_retry(Method::GET, &url, None).await?;
243
244 response.ok_or_else(|| ClientError::NotFound {
245 kind: kind.display_name().to_string(),
246 name: name.to_string(),
247 })
248 }
249
250 #[instrument(skip(self, definition))]
256 pub async fn create_or_update(
257 &self,
258 kind: ResourceKind,
259 name: &str,
260 definition: &Value,
261 ) -> Result<Option<Value>, ClientError> {
262 let url = self.resource_url(kind, name);
263 self.request_with_retry(Method::PUT, &url, Some(definition))
264 .await
265 }
266
267 #[instrument(skip(self))]
269 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
270 let url = self.resource_url(kind, name);
271 self.request_with_retry(Method::DELETE, &url, None).await?;
272 Ok(())
273 }
274
275 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
277 match self.get(kind, name).await {
278 Ok(_) => Ok(true),
279 Err(ClientError::NotFound { .. }) => Ok(false),
280 Err(e) => Err(e),
281 }
282 }
283
284 pub fn auth_method(&self) -> &'static str {
286 self.auth.method_name()
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Stable API version passed to every test client.
    const STABLE_VERSION: &str = "2024-07-01";
    /// Preview API version passed to every test client.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider that always returns a fixed token, so tests never touch real credentials.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok("fake-token".to_string())
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client rooted at `base_url` with the fixed stable/preview API versions.
    fn make_client_at(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_string(),
            STABLE_VERSION.to_string(),
            PREVIEW_VERSION.to_string(),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Builds the default test client rooted at `https://test-svc.search.windows.net`.
    fn make_client() -> AzureSearchClient {
        make_client_at("https://test-svc.search.windows.net")
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_stable() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_resource_url_preview() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(STABLE_VERSION),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    /// `new_for_server` needs a `Config` and a real auth provider, so it cannot be called
    /// here. This test checks the same `https://<server>.search.windows.net` base-URL shape
    /// through `with_auth` instead.
    #[test]
    fn test_collection_url_uses_custom_base_url() {
        let client = make_client_at("https://other-svc.search.windows.net");
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://other-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all().iter().filter(|kind| kind.is_preview()) {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(PREVIEW_VERSION),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        assert_eq!(retry_delay(&err, 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        assert_eq!(retry_delay(&err, 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        // The server-provided delay wins regardless of how many attempts have been made.
        for attempt in 0..3 {
            assert_eq!(retry_delay(&err, attempt), Duration::from_secs(30));
        }
    }

    /// 60 mirrors the value `request` reports for a 429 without a usable `Retry-After`
    /// header. This test only checks that such a value is passed through unchanged; it does
    /// not exercise `request` itself.
    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = ClientError::ServiceUnavailable("temporarily unavailable".to_string());
        let delays: Vec<Duration> = (0..MAX_RETRIES).map(|i| retry_delay(&err, i)).collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_secs(1),
                Duration::from_secs(2),
                Duration::from_secs(4),
            ]
        );
    }

    /// `retry_delay` does not decide retryability; it computes a backoff for any error.
    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(1));
    }
}