1use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15const MAX_RETRIES: u32 = 3;
17
18const INITIAL_BACKOFF_SECS: u64 = 1;
20
21fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26 match error {
27 ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28 _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29 }
30}
31
32pub struct AzureSearchClient {
34 http: Client,
35 auth: Box<dyn AuthProvider>,
36 base_url: String,
37 preview_api_version: String,
38}
39
40impl AzureSearchClient {
41 pub fn new(config: &Config) -> Result<Self, ClientError> {
43 let auth = get_auth_provider()?;
44 let http = Client::builder()
45 .timeout(std::time::Duration::from_secs(30))
46 .build()?;
47
48 Ok(Self {
49 http,
50 auth,
51 base_url: config.service_url(),
52 preview_api_version: config.api_version_for(true).to_string(),
53 })
54 }
55
56 pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
58 let auth = get_auth_provider()?;
59 let http = Client::builder()
60 .timeout(std::time::Duration::from_secs(30))
61 .build()?;
62
63 Ok(Self {
64 http,
65 auth,
66 base_url: format!("https://{}.search.windows.net", server_name),
67 preview_api_version: config.api_version_for(true).to_string(),
68 })
69 }
70
71 pub fn with_auth(
73 base_url: String,
74 preview_api_version: String,
75 auth: Box<dyn AuthProvider>,
76 ) -> Result<Self, ClientError> {
77 let http = Client::builder()
78 .timeout(std::time::Duration::from_secs(30))
79 .build()?;
80
81 Ok(Self {
82 http,
83 auth,
84 base_url,
85 preview_api_version,
86 })
87 }
88
89 fn api_version_for(&self, _kind: ResourceKind) -> &str {
94 &self.preview_api_version
95 }
96
97 fn collection_url(&self, kind: ResourceKind) -> String {
99 format!(
100 "{}/{}?api-version={}",
101 self.base_url,
102 kind.api_path(),
103 self.api_version_for(kind)
104 )
105 }
106
107 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
109 format!(
110 "{}/{}/{}?api-version={}",
111 self.base_url,
112 kind.api_path(),
113 name,
114 self.api_version_for(kind)
115 )
116 }
117
118 async fn request(
120 &self,
121 method: Method,
122 url: &str,
123 body: Option<&Value>,
124 ) -> Result<Option<Value>, ClientError> {
125 let token = self.auth.get_token()?;
126
127 let mut request = self
128 .http
129 .request(method.clone(), url)
130 .header("Authorization", format!("Bearer {}", token))
131 .header("Content-Type", "application/json");
132
133 if let Some(json) = body {
134 request = request.json(json);
135 }
136
137 debug!("Request: {} {}", method, url);
138 let response = request.send().await?;
139 let status = response.status();
140
141 if status == StatusCode::NO_CONTENT {
142 return Ok(None);
143 }
144
145 let body = response.text().await?;
146
147 if status.is_success() {
148 if body.is_empty() {
149 Ok(None)
150 } else {
151 let value: Value = serde_json::from_str(&body)?;
152 Ok(Some(value))
153 }
154 } else {
155 match status {
156 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
157 kind: "resource".to_string(),
158 name: url.to_string(),
159 }),
160 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
161 kind: "resource".to_string(),
162 name: url.to_string(),
163 }),
164 StatusCode::TOO_MANY_REQUESTS => {
165 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
167 }
168 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
169 _ => Err(ClientError::from_response_with_url(
170 status.as_u16(),
171 &body,
172 Some(url),
173 )),
174 }
175 }
176 }
177
178 async fn request_with_retry(
184 &self,
185 method: Method,
186 url: &str,
187 body: Option<&Value>,
188 ) -> Result<Option<Value>, ClientError> {
189 let mut attempt = 0u32;
190 loop {
191 match self.request(method.clone(), url, body).await {
192 Ok(value) => return Ok(value),
193 Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
194 let delay = retry_delay(&err, attempt);
195 warn!(
196 "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
197 method,
198 url,
199 attempt + 1,
200 MAX_RETRIES + 1,
201 err,
202 delay,
203 );
204 tokio::time::sleep(delay).await;
205 attempt += 1;
206 }
207 Err(err) => return Err(err),
208 }
209 }
210 }
211
212 #[instrument(skip(self))]
214 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
215 let url = self.collection_url(kind);
216 let response = self.request_with_retry(Method::GET, &url, None).await?;
217
218 match response {
219 Some(value) => {
220 let items = value
222 .get("value")
223 .and_then(|v| v.as_array())
224 .cloned()
225 .unwrap_or_default();
226 Ok(items)
227 }
228 None => Ok(Vec::new()),
229 }
230 }
231
232 #[instrument(skip(self))]
234 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
235 let url = self.resource_url(kind, name);
236 let response = self.request_with_retry(Method::GET, &url, None).await?;
237
238 response.ok_or_else(|| ClientError::NotFound {
239 kind: kind.display_name().to_string(),
240 name: name.to_string(),
241 })
242 }
243
244 #[instrument(skip(self, definition))]
250 pub async fn create_or_update(
251 &self,
252 kind: ResourceKind,
253 name: &str,
254 definition: &Value,
255 ) -> Result<Option<Value>, ClientError> {
256 let url = self.resource_url(kind, name);
257 self.request_with_retry(Method::PUT, &url, Some(definition))
258 .await
259 }
260
261 #[instrument(skip(self))]
263 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
264 let url = self.resource_url(kind, name);
265 self.request_with_retry(Method::DELETE, &url, None).await?;
266 Ok(())
267 }
268
269 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
271 match self.get(kind, name).await {
272 Ok(_) => Ok(true),
273 Err(ClientError::NotFound { .. }) => Ok(false),
274 Err(e) => Err(e),
275 }
276 }
277
278 pub fn auth_method(&self) -> &'static str {
280 self.auth.method_name()
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Preview API version every test client is configured with.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider that always returns a fixed token, so tests never touch
    /// real credentials.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok("fake-token".to_string())
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client rooted at `base_url` that authenticates with [`FakeAuth`].
    fn client_for(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_string(),
            PREVIEW_VERSION.to_string(),
            Box::new(FakeAuth),
        )
        .expect("building a reqwest client with a fixed timeout should succeed")
    }

    /// Client for the `test-svc` service used by most tests.
    fn make_client() -> AzureSearchClient {
        client_for("https://test-svc.search.windows.net")
    }

    #[test]
    fn test_collection_url_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_base_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let client = make_client();
        let url = client.collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_knowledge_base_uses_preview_version() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_knowledge_source_uses_preview_version() {
        let client = make_client();
        let url = client.resource_url(ResourceKind::KnowledgeSource, "my-ks");
        assert_eq!(
            url,
            "https://test-svc.search.windows.net/knowledgesources/my-ks?api-version=2025-11-01-preview"
        );
    }

    /// `new_for_server` needs a real auth provider, so it is not called here;
    /// instead this checks that `with_auth` uses a base URL of the same
    /// `https://<name>.search.windows.net` shape verbatim.
    #[test]
    fn test_with_auth_uses_supplied_base_url() {
        let base_url = format!("https://{}.search.windows.net", "other-svc");
        let client = client_for(&base_url);
        let url = client.collection_url(ResourceKind::Index);
        assert_eq!(
            url,
            "https://other-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains(PREVIEW_VERSION),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_all_kinds_resource_url_use_preview_version() {
        let client = make_client();
        let expected_suffix = format!("/x?api-version={}", PREVIEW_VERSION);
        for kind in ResourceKind::all() {
            let url = client.resource_url(*kind, "x");
            assert!(
                url.ends_with(&expected_suffix),
                "{:?} resource URL should end with {}, got: {}",
                kind,
                expected_suffix,
                url
            );
        }
    }

    #[test]
    fn test_auth_method_reports_provider_name() {
        assert_eq!(make_client().auth_method(), "Fake");
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 1);
        assert_eq!(delay, Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        let err = ClientError::ServiceUnavailable("down".to_string());
        let delay = retry_delay(&err, 2);
        assert_eq!(delay, Duration::from_secs(4));
    }

    /// A rate-limit error's own wait overrides backoff on every attempt.
    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let err = ClientError::RateLimited { retry_after: 30 };
        assert_eq!(retry_delay(&err, 0), Duration::from_secs(30));
        assert_eq!(retry_delay(&err, 1), Duration::from_secs(30));
        assert_eq!(retry_delay(&err, 2), Duration::from_secs(30));
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let err = ClientError::RateLimited { retry_after: 60 };
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let err = ClientError::ServiceUnavailable("temporarily unavailable".to_string());
        let delays: Vec<Duration> = (0..MAX_RETRIES).map(|i| retry_delay(&err, i)).collect();
        assert_eq!(
            delays,
            vec![
                Duration::from_secs(1),
                Duration::from_secs(2),
                Duration::from_secs(4),
            ]
        );
    }

    /// `retry_delay` does not consult `is_retryable()`; deciding whether to
    /// retry at all is left to `request_with_retry`.
    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        let err = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        let delay = retry_delay(&err, 0);
        assert_eq!(delay, Duration::from_secs(1));
    }
}