1use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
/// Maximum number of retries after the initial attempt for errors whose
/// `is_retryable()` is true; at most `MAX_RETRIES + 1` requests are sent.
const MAX_RETRIES: u32 = 3;

/// Base delay, in seconds, for exponential backoff: retry after attempt `n`
/// (0-based) waits `INITIAL_BACKOFF_SECS * 2^n` seconds unless the error
/// carries its own rate-limit delay.
const INITIAL_BACKOFF_SECS: u64 = 1;

21fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26 match error {
27 ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28 _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29 }
30}
31
/// Authenticated JSON client for a search service's REST API.
///
/// It lists, fetches, creates/updates and deletes resources identified by a
/// [`ResourceKind`] and a name. Stable and preview resource kinds are sent
/// with different `api-version` query parameters.
pub struct AzureSearchClient {
    /// Underlying HTTP client; the constructors configure a 30-second timeout.
    http: Client,
    /// Source of the bearer token sent in the `Authorization` header.
    auth: Box<dyn AuthProvider>,
    /// Service root URL (e.g. `https://<name>.search.windows.net`). URLs are
    /// built as `{base_url}/{path}`, so this presumably has no trailing slash.
    base_url: String,
    /// `api-version` used for resource kinds where `is_preview()` is false.
    api_version: String,
    /// `api-version` used for resource kinds where `is_preview()` is true.
    preview_api_version: String,
}

41impl AzureSearchClient {
42 pub fn new(config: &Config) -> Result<Self, ClientError> {
44 let auth = get_auth_provider()?;
45 let http = Client::builder()
46 .timeout(std::time::Duration::from_secs(30))
47 .build()?;
48
49 Ok(Self {
50 http,
51 auth,
52 base_url: config.service_url(),
53 api_version: config.api_version_for(false).to_string(),
54 preview_api_version: config.api_version_for(true).to_string(),
55 })
56 }
57
58 pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
60 let auth = get_auth_provider()?;
61 let http = Client::builder()
62 .timeout(std::time::Duration::from_secs(30))
63 .build()?;
64
65 Ok(Self {
66 http,
67 auth,
68 base_url: format!("https://{}.search.windows.net", server_name),
69 api_version: config.api_version_for(false).to_string(),
70 preview_api_version: config.api_version_for(true).to_string(),
71 })
72 }
73
74 pub fn with_auth(
76 base_url: String,
77 api_version: String,
78 preview_api_version: String,
79 auth: Box<dyn AuthProvider>,
80 ) -> Result<Self, ClientError> {
81 let http = Client::builder()
82 .timeout(std::time::Duration::from_secs(30))
83 .build()?;
84
85 Ok(Self {
86 http,
87 auth,
88 base_url,
89 api_version,
90 preview_api_version,
91 })
92 }
93
94 fn api_version_for(&self, kind: ResourceKind) -> &str {
96 if kind.is_preview() {
97 &self.preview_api_version
98 } else {
99 &self.api_version
100 }
101 }
102
103 fn collection_url(&self, kind: ResourceKind) -> String {
105 format!(
106 "{}/{}?api-version={}",
107 self.base_url,
108 kind.api_path(),
109 self.api_version_for(kind)
110 )
111 }
112
113 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
115 format!(
116 "{}/{}/{}?api-version={}",
117 self.base_url,
118 kind.api_path(),
119 name,
120 self.api_version_for(kind)
121 )
122 }
123
124 async fn request(
126 &self,
127 method: Method,
128 url: &str,
129 body: Option<&Value>,
130 ) -> Result<Option<Value>, ClientError> {
131 let token = self.auth.get_token()?;
132
133 let mut request = self
134 .http
135 .request(method.clone(), url)
136 .header("Authorization", format!("Bearer {}", token))
137 .header("Content-Type", "application/json");
138
139 if let Some(json) = body {
140 request = request.json(json);
141 }
142
143 debug!("Request: {} {}", method, url);
144 let response = request.send().await?;
145 let status = response.status();
146
147 if status == StatusCode::NO_CONTENT {
148 return Ok(None);
149 }
150
151 let body = response.text().await?;
152
153 if status.is_success() {
154 if body.is_empty() {
155 Ok(None)
156 } else {
157 let value: Value = serde_json::from_str(&body)?;
158 Ok(Some(value))
159 }
160 } else {
161 match status {
162 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
163 kind: "resource".to_string(),
164 name: url.to_string(),
165 }),
166 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
167 kind: "resource".to_string(),
168 name: url.to_string(),
169 }),
170 StatusCode::TOO_MANY_REQUESTS => {
171 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
173 }
174 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
175 _ => Err(ClientError::from_response(status.as_u16(), &body)),
176 }
177 }
178 }
179
180 async fn request_with_retry(
186 &self,
187 method: Method,
188 url: &str,
189 body: Option<&Value>,
190 ) -> Result<Option<Value>, ClientError> {
191 let mut attempt = 0u32;
192 loop {
193 match self.request(method.clone(), url, body).await {
194 Ok(value) => return Ok(value),
195 Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
196 let delay = retry_delay(&err, attempt);
197 warn!(
198 "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
199 method,
200 url,
201 attempt + 1,
202 MAX_RETRIES + 1,
203 err,
204 delay,
205 );
206 tokio::time::sleep(delay).await;
207 attempt += 1;
208 }
209 Err(err) => return Err(err),
210 }
211 }
212 }
213
214 #[instrument(skip(self))]
216 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
217 let url = self.collection_url(kind);
218 let response = self.request_with_retry(Method::GET, &url, None).await?;
219
220 match response {
221 Some(value) => {
222 let items = value
224 .get("value")
225 .and_then(|v| v.as_array())
226 .cloned()
227 .unwrap_or_default();
228 Ok(items)
229 }
230 None => Ok(Vec::new()),
231 }
232 }
233
234 #[instrument(skip(self))]
236 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
237 let url = self.resource_url(kind, name);
238 let response = self.request_with_retry(Method::GET, &url, None).await?;
239
240 response.ok_or_else(|| ClientError::NotFound {
241 kind: kind.display_name().to_string(),
242 name: name.to_string(),
243 })
244 }
245
246 #[instrument(skip(self, definition))]
252 pub async fn create_or_update(
253 &self,
254 kind: ResourceKind,
255 name: &str,
256 definition: &Value,
257 ) -> Result<Option<Value>, ClientError> {
258 let url = self.resource_url(kind, name);
259 self.request_with_retry(Method::PUT, &url, Some(definition))
260 .await
261 }
262
263 #[instrument(skip(self))]
265 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
266 let url = self.resource_url(kind, name);
267 self.request_with_retry(Method::DELETE, &url, None).await?;
268 Ok(())
269 }
270
271 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
273 match self.get(kind, name).await {
274 Ok(_) => Ok(true),
275 Err(ClientError::NotFound { .. }) => Ok(false),
276 Err(e) => Err(e),
277 }
278 }
279
280 pub fn auth_method(&self) -> &'static str {
282 self.auth.method_name()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Auth provider stub that always returns a fixed token, so tests never
    /// depend on real credentials.
    struct FakeAuth;
    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok("fake-token".to_string())
        }
        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client for a dummy service with known stable and preview
    /// API versions.
    fn make_client() -> AzureSearchClient {
        AzureSearchClient::with_auth(
            "https://test-svc.search.windows.net".to_string(),
            "2024-07-01".to_string(),
            "2025-11-01-preview".to_string(),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_stable() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_resource_url_preview() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_api_version_for_selects_by_preview_flag() {
        let client = make_client();
        assert_eq!(client.api_version_for(ResourceKind::Index), "2024-07-01");
        assert_eq!(
            client.api_version_for(ResourceKind::KnowledgeBase),
            "2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains("2024-07-01"),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    /// Checks that a custom base URL passed to `with_auth` is used verbatim.
    ///
    /// `new_for_server` itself is not exercised here because it calls
    /// `get_auth_provider()`, which may need real credentials. This test only
    /// covers the URL shape that `new_for_server` would produce.
    #[test]
    fn test_with_auth_uses_custom_base_url() {
        let client = AzureSearchClient::with_auth(
            "https://other-svc.search.windows.net".to_string(),
            "2024-07-01".to_string(),
            "2025-11-01-preview".to_string(),
            Box::new(FakeAuth),
        )
        .unwrap();
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://other-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            if kind.is_preview() {
                let url = client.collection_url(*kind);
                assert!(
                    url.contains("2025-11-01-preview"),
                    "{:?} should use preview API version, got: {}",
                    kind,
                    url
                );
            }
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 1);
        assert_eq!(delay, Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 2);
        assert_eq!(delay, Duration::from_secs(4));
    }

    /// Rate-limited errors ignore the attempt number and use `retry_after`.
    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(30));
        assert_eq!(retry_delay(&err, 1), Duration::from_secs(30));
        assert_eq!(retry_delay(&err, 2), Duration::from_secs(30));
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = ClientError::ServiceUnavailable("temporarily unavailable".to_string());
        let delays: Vec<Duration> = (0..MAX_RETRIES).map(|i| retry_delay(&err, i)).collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_secs(1),
                Duration::from_secs(2),
                Duration::from_secs(4),
            ]
        );
    }

    /// `retry_delay` does not consult `is_retryable()`. Whether to retry is
    /// decided by the caller, so any non-rate-limit error gets plain backoff.
    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(1));
    }
}