1use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::config::SearchServiceConfig;
10use hoist_core::resources::ResourceKind;
11
12use crate::auth::{AuthProvider, get_auth_provider};
13use crate::error::ClientError;
14
15const MAX_RETRIES: u32 = 3;
17
18const INITIAL_BACKOFF_SECS: u64 = 1;
20
21fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26 match error {
27 ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28 _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29 }
30}
31
32pub struct AzureSearchClient {
34 http: Client,
35 auth: Box<dyn AuthProvider>,
36 base_url: String,
37 preview_api_version: String,
38}
39
40impl AzureSearchClient {
41 pub fn from_service_config(service: &SearchServiceConfig) -> Result<Self, ClientError> {
43 let auth = get_auth_provider()?;
44 let http = Client::builder().timeout(Duration::from_secs(30)).build()?;
45
46 Ok(Self {
47 http,
48 auth,
49 base_url: service.service_url(),
50 preview_api_version: service.preview_api_version.clone(),
51 })
52 }
53
54 pub fn with_auth(
56 base_url: String,
57 preview_api_version: String,
58 auth: Box<dyn AuthProvider>,
59 ) -> Result<Self, ClientError> {
60 let http = Client::builder()
61 .timeout(std::time::Duration::from_secs(30))
62 .build()?;
63
64 Ok(Self {
65 http,
66 auth,
67 base_url,
68 preview_api_version,
69 })
70 }
71
72 fn api_version_for(&self, _kind: ResourceKind) -> &str {
77 &self.preview_api_version
78 }
79
80 fn collection_url(&self, kind: ResourceKind) -> String {
82 format!(
83 "{}/{}?api-version={}",
84 self.base_url,
85 kind.api_path(),
86 self.api_version_for(kind)
87 )
88 }
89
90 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
92 format!(
93 "{}/{}/{}?api-version={}",
94 self.base_url,
95 kind.api_path(),
96 urlencoding::encode(name),
97 self.api_version_for(kind)
98 )
99 }
100
101 async fn request(
103 &self,
104 method: Method,
105 url: &str,
106 body: Option<&Value>,
107 ) -> Result<Option<Value>, ClientError> {
108 let token = self.auth.get_token()?;
109
110 let mut request = self
111 .http
112 .request(method.clone(), url)
113 .header("Authorization", format!("Bearer {}", token))
114 .header("Content-Type", "application/json");
115
116 if let Some(json) = body {
117 request = request.json(json);
118 }
119
120 debug!("Request: {} {}", method, url);
121 let response = request.send().await?;
122 let status = response.status();
123
124 if status == StatusCode::NO_CONTENT {
125 return Ok(None);
126 }
127
128 let body = response.text().await?;
129
130 if status.is_success() {
131 if body.is_empty() {
132 Ok(None)
133 } else {
134 let value: Value = serde_json::from_str(&body)?;
135 Ok(Some(value))
136 }
137 } else {
138 match status {
139 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
140 kind: "resource".to_string(),
141 name: url.to_string(),
142 }),
143 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
144 kind: "resource".to_string(),
145 name: url.to_string(),
146 }),
147 StatusCode::TOO_MANY_REQUESTS => {
148 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
150 }
151 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
152 _ => Err(ClientError::from_response_with_url(
153 status.as_u16(),
154 &body,
155 Some(url),
156 )),
157 }
158 }
159 }
160
161 async fn request_with_retry(
167 &self,
168 method: Method,
169 url: &str,
170 body: Option<&Value>,
171 ) -> Result<Option<Value>, ClientError> {
172 let mut attempt = 0u32;
173 loop {
174 match self.request(method.clone(), url, body).await {
175 Ok(value) => return Ok(value),
176 Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
177 let delay = retry_delay(&err, attempt);
178 warn!(
179 "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
180 method,
181 url,
182 attempt + 1,
183 MAX_RETRIES + 1,
184 err,
185 delay,
186 );
187 tokio::time::sleep(delay).await;
188 attempt += 1;
189 }
190 Err(err) => return Err(err),
191 }
192 }
193 }
194
195 #[instrument(skip(self))]
197 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
198 let url = self.collection_url(kind);
199 let response = self.request_with_retry(Method::GET, &url, None).await?;
200
201 match response {
202 Some(value) => {
203 let items = value
205 .get("value")
206 .and_then(|v| v.as_array())
207 .cloned()
208 .unwrap_or_default();
209 Ok(items)
210 }
211 None => Ok(Vec::new()),
212 }
213 }
214
215 #[instrument(skip(self))]
217 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
218 let url = self.resource_url(kind, name);
219 let response = self.request_with_retry(Method::GET, &url, None).await?;
220
221 response.ok_or_else(|| ClientError::NotFound {
222 kind: kind.display_name().to_string(),
223 name: name.to_string(),
224 })
225 }
226
227 #[instrument(skip(self, definition))]
233 pub async fn create_or_update(
234 &self,
235 kind: ResourceKind,
236 name: &str,
237 definition: &Value,
238 ) -> Result<Option<Value>, ClientError> {
239 let url = self.resource_url(kind, name);
240 self.request_with_retry(Method::PUT, &url, Some(definition))
241 .await
242 }
243
244 #[instrument(skip(self))]
246 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
247 let url = self.resource_url(kind, name);
248 self.request_with_retry(Method::DELETE, &url, None).await?;
249 Ok(())
250 }
251
252 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
254 match self.get(kind, name).await {
255 Ok(_) => Ok(true),
256 Err(ClientError::NotFound { .. }) => Ok(false),
257 Err(e) => Err(e),
258 }
259 }
260
261 pub fn auth_method(&self) -> &'static str {
263 self.auth.method_name()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Auth provider stub that always yields the same token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client rooted at `base_url` with the preview API version and fake auth.
    fn client_at(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_string(),
            "2025-11-01-preview".to_string(),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Client pointed at the default test service.
    fn make_client() -> AzureSearchClient {
        client_at("https://test-svc.search.windows.net")
    }

    /// A `ServiceUnavailable` error carrying `message`.
    fn unavailable(message: &str) -> ClientError {
        ClientError::ServiceUnavailable(message.to_string())
    }

    // --- URL construction ---

    #[test]
    fn test_collection_url_uses_preview_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::Index),
            "https://test-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_base_uses_preview_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::KnowledgeBase),
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        assert_eq!(
            make_client().collection_url(ResourceKind::KnowledgeSource),
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        assert_eq!(
            make_client().resource_url(ResourceKind::Index, "my-index"),
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_knowledge_base_uses_preview_version() {
        assert_eq!(
            make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb"),
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        let other = client_at("https://other-svc.search.windows.net");
        assert_eq!(
            other.collection_url(ResourceKind::Index),
            "https://other-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_kinds_use_preview_version() {
        let client = make_client();
        for &kind in ResourceKind::all() {
            let url = client.collection_url(kind);
            assert!(
                url.contains("2025-11-01-preview"),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }

    // --- Retry delay ---

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        assert_eq!(retry_delay(&unavailable("down"), 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        assert_eq!(retry_delay(&unavailable("down"), 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        assert_eq!(retry_delay(&unavailable("down"), 2), Duration::from_secs(4));
    }

    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        // The server-provided delay wins no matter which attempt failed.
        let limited = ClientError::RateLimited { retry_after: 30 };
        for attempt in 0..3 {
            assert_eq!(retry_delay(&limited, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let limited = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&limited, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let failure = unavailable("temporarily unavailable");
        let mut observed = Vec::new();
        for attempt in 0..MAX_RETRIES {
            observed.push(retry_delay(&failure, attempt));
        }
        let expected: Vec<Duration> = [1, 2, 4].into_iter().map(Duration::from_secs).collect();
        assert_eq!(observed, expected);
    }

    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        // `retry_delay` does not check retryability; it only computes a duration.
        let bad_request = ClientError::Api {
            status: 400,
            message: "bad request".to_string(),
        };
        assert_eq!(retry_delay(&bad_request, 0), Duration::from_secs(1));
    }
}