1use std::time::Duration;
4
5use reqwest::{Client, Method, StatusCode};
6use serde_json::Value;
7use tracing::{debug, instrument, warn};
8
9use hoist_core::resources::ResourceKind;
10use hoist_core::Config;
11
12use crate::auth::{get_auth_provider, AuthProvider};
13use crate::error::ClientError;
14
15const MAX_RETRIES: u32 = 3;
17
18const INITIAL_BACKOFF_SECS: u64 = 1;
20
21fn retry_delay(error: &ClientError, attempt: u32) -> Duration {
26 match error {
27 ClientError::RateLimited { retry_after } => Duration::from_secs(*retry_after),
28 _ => Duration::from_secs(INITIAL_BACKOFF_SECS * 2u64.pow(attempt)),
29 }
30}
31
32const AGENTIC_RETRIEVAL_API_VERSION: &str = "2025-08-01-preview";
43
44pub struct AzureSearchClient {
46 http: Client,
47 auth: Box<dyn AuthProvider>,
48 base_url: String,
49 preview_api_version: String,
50}
51
52impl AzureSearchClient {
53 pub fn new(config: &Config) -> Result<Self, ClientError> {
55 let auth = get_auth_provider()?;
56 let http = Client::builder()
57 .timeout(std::time::Duration::from_secs(30))
58 .build()?;
59
60 Ok(Self {
61 http,
62 auth,
63 base_url: config.service_url(),
64 preview_api_version: config.api_version_for(true).to_string(),
65 })
66 }
67
68 pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
70 let auth = get_auth_provider()?;
71 let http = Client::builder()
72 .timeout(std::time::Duration::from_secs(30))
73 .build()?;
74
75 Ok(Self {
76 http,
77 auth,
78 base_url: format!("https://{}.search.windows.net", server_name),
79 preview_api_version: config.api_version_for(true).to_string(),
80 })
81 }
82
83 pub fn with_auth(
85 base_url: String,
86 preview_api_version: String,
87 auth: Box<dyn AuthProvider>,
88 ) -> Result<Self, ClientError> {
89 let http = Client::builder()
90 .timeout(std::time::Duration::from_secs(30))
91 .build()?;
92
93 Ok(Self {
94 http,
95 auth,
96 base_url,
97 preview_api_version,
98 })
99 }
100
101 fn api_version_for(&self, kind: ResourceKind) -> &str {
109 match kind {
110 ResourceKind::KnowledgeBase | ResourceKind::KnowledgeSource => {
111 AGENTIC_RETRIEVAL_API_VERSION
112 }
113 _ => &self.preview_api_version,
114 }
115 }
116
117 fn collection_url(&self, kind: ResourceKind) -> String {
119 format!(
120 "{}/{}?api-version={}",
121 self.base_url,
122 kind.api_path(),
123 self.api_version_for(kind)
124 )
125 }
126
127 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
129 format!(
130 "{}/{}/{}?api-version={}",
131 self.base_url,
132 kind.api_path(),
133 name,
134 self.api_version_for(kind)
135 )
136 }
137
138 async fn request(
140 &self,
141 method: Method,
142 url: &str,
143 body: Option<&Value>,
144 ) -> Result<Option<Value>, ClientError> {
145 let token = self.auth.get_token()?;
146
147 let mut request = self
148 .http
149 .request(method.clone(), url)
150 .header("Authorization", format!("Bearer {}", token))
151 .header("Content-Type", "application/json");
152
153 if let Some(json) = body {
154 request = request.json(json);
155 }
156
157 debug!("Request: {} {}", method, url);
158 let response = request.send().await?;
159 let status = response.status();
160
161 if status == StatusCode::NO_CONTENT {
162 return Ok(None);
163 }
164
165 let body = response.text().await?;
166
167 if status.is_success() {
168 if body.is_empty() {
169 Ok(None)
170 } else {
171 let value: Value = serde_json::from_str(&body)?;
172 Ok(Some(value))
173 }
174 } else {
175 match status {
176 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
177 kind: "resource".to_string(),
178 name: url.to_string(),
179 }),
180 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
181 kind: "resource".to_string(),
182 name: url.to_string(),
183 }),
184 StatusCode::TOO_MANY_REQUESTS => {
185 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
187 }
188 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
189 _ => Err(ClientError::from_response_with_url(
190 status.as_u16(),
191 &body,
192 Some(url),
193 )),
194 }
195 }
196 }
197
198 async fn request_with_retry(
204 &self,
205 method: Method,
206 url: &str,
207 body: Option<&Value>,
208 ) -> Result<Option<Value>, ClientError> {
209 let mut attempt = 0u32;
210 loop {
211 match self.request(method.clone(), url, body).await {
212 Ok(value) => return Ok(value),
213 Err(err) if err.is_retryable() && attempt < MAX_RETRIES => {
214 let delay = retry_delay(&err, attempt);
215 warn!(
216 "Request {} {} failed (attempt {}/{}): {}. Retrying in {:?}",
217 method,
218 url,
219 attempt + 1,
220 MAX_RETRIES + 1,
221 err,
222 delay,
223 );
224 tokio::time::sleep(delay).await;
225 attempt += 1;
226 }
227 Err(err) => return Err(err),
228 }
229 }
230 }
231
232 #[instrument(skip(self))]
234 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
235 let url = self.collection_url(kind);
236 let response = self.request_with_retry(Method::GET, &url, None).await?;
237
238 match response {
239 Some(value) => {
240 let items = value
242 .get("value")
243 .and_then(|v| v.as_array())
244 .cloned()
245 .unwrap_or_default();
246 Ok(items)
247 }
248 None => Ok(Vec::new()),
249 }
250 }
251
252 #[instrument(skip(self))]
254 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
255 let url = self.resource_url(kind, name);
256 let response = self.request_with_retry(Method::GET, &url, None).await?;
257
258 response.ok_or_else(|| ClientError::NotFound {
259 kind: kind.display_name().to_string(),
260 name: name.to_string(),
261 })
262 }
263
264 #[instrument(skip(self, definition))]
270 pub async fn create_or_update(
271 &self,
272 kind: ResourceKind,
273 name: &str,
274 definition: &Value,
275 ) -> Result<Option<Value>, ClientError> {
276 let url = self.resource_url(kind, name);
277 self.request_with_retry(Method::PUT, &url, Some(definition))
278 .await
279 }
280
281 #[instrument(skip(self))]
283 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
284 let url = self.resource_url(kind, name);
285 self.request_with_retry(Method::DELETE, &url, None).await?;
286 Ok(())
287 }
288
289 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
291 match self.get(kind, name).await {
292 Ok(_) => Ok(true),
293 Err(ClientError::NotFound { .. }) => Ok(false),
294 Err(e) => Err(e),
295 }
296 }
297
298 pub fn auth_method(&self) -> &'static str {
300 self.auth.method_name()
301 }
302}
303
#[cfg(test)]
mod tests {
    //! URL/API-version selection and retry-delay unit tests. No network
    //! traffic is generated: clients are built with a stub auth provider.

    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Auth stub that hands out a constant token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client for `service_root` with the test preview API version.
    fn client_for(service_root: &str) -> AzureSearchClient {
        let preview = String::from("2025-11-01-preview");
        AzureSearchClient::with_auth(service_root.to_string(), preview, Box::new(FakeAuth))
            .unwrap()
    }

    /// The default client used by most tests.
    fn make_client() -> AzureSearchClient {
        client_for("https://test-svc.search.windows.net")
    }

    #[test]
    fn test_collection_url_uses_preview_version() {
        let built = make_client().collection_url(ResourceKind::Index);
        assert_eq!(
            built,
            "https://test-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_base_uses_agentic_version() {
        let built = make_client().collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            built,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-08-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_agentic_version() {
        let built = make_client().collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            built,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-08-01-preview"
        );
    }

    #[test]
    fn test_resource_url_uses_preview_version() {
        let built = make_client().resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            built,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_knowledge_base_uses_agentic_version() {
        let built = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            built,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-08-01-preview"
        );
    }

    /// Exercises the base URL shape that `new_for_server` produces. It goes
    /// through `with_auth` because `new_for_server` calls
    /// `get_auth_provider()`, which this test avoids.
    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        let other = client_for("https://other-svc.search.windows.net");
        let built = other.collection_url(ResourceKind::Index);
        assert_eq!(
            built,
            "https://other-svc.search.windows.net/indexes?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_kinds_use_expected_api_version() {
        let client = make_client();
        for &kind in ResourceKind::all() {
            let url = client.collection_url(kind);
            let is_agentic = matches!(
                kind,
                ResourceKind::KnowledgeBase | ResourceKind::KnowledgeSource
            );
            if is_agentic {
                assert!(
                    url.contains("2025-08-01-preview"),
                    "{:?} should use agentic retrieval API version, got: {}",
                    kind,
                    url
                );
            } else {
                assert!(
                    url.contains("2025-11-01-preview"),
                    "{:?} should use preview API version, got: {}",
                    kind,
                    url
                );
            }
        }
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_0() {
        let unavailable = ClientError::ServiceUnavailable(String::from("down"));
        assert_eq!(retry_delay(&unavailable, 0), Duration::from_secs(1));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_1() {
        let unavailable = ClientError::ServiceUnavailable(String::from("down"));
        assert_eq!(retry_delay(&unavailable, 1), Duration::from_secs(2));
    }

    #[test]
    fn test_retry_delay_exponential_backoff_attempt_2() {
        let unavailable = ClientError::ServiceUnavailable(String::from("down"));
        assert_eq!(retry_delay(&unavailable, 2), Duration::from_secs(4));
    }

    /// The rate-limit delay ignores the attempt number entirely.
    #[test]
    fn test_retry_delay_rate_limited_uses_retry_after() {
        let throttled = ClientError::RateLimited { retry_after: 30 };
        for attempt in 0..=2u32 {
            assert_eq!(retry_delay(&throttled, attempt), Duration::from_secs(30));
        }
    }

    #[test]
    fn test_retry_delay_rate_limited_default_retry_after() {
        let throttled = ClientError::RateLimited { retry_after: 60 };
        assert_eq!(retry_delay(&throttled, 0), Duration::from_secs(60));
    }

    #[test]
    fn test_retry_constants() {
        assert_eq!(MAX_RETRIES, 3);
        assert_eq!(INITIAL_BACKOFF_SECS, 1);
    }

    #[test]
    fn test_retry_delay_backoff_sequence() {
        let unavailable =
            ClientError::ServiceUnavailable(String::from("temporarily unavailable"));
        let mut observed: Vec<Duration> = Vec::new();
        for attempt in 0..MAX_RETRIES {
            observed.push(retry_delay(&unavailable, attempt));
        }
        let expected: Vec<Duration> = [1u64, 2, 4]
            .iter()
            .copied()
            .map(Duration::from_secs)
            .collect();
        assert_eq!(observed, expected);
    }

    /// `retry_delay` does not consult `is_retryable`; it computes backoff for
    /// any non-rate-limit error.
    #[test]
    fn test_non_retryable_error_still_computes_delay() {
        let bad_request = ClientError::Api {
            status: 400,
            message: String::from("bad request"),
        };
        assert_eq!(retry_delay(&bad_request, 0), Duration::from_secs(1));
    }
}