1use reqwest::{Client, Method, StatusCode};
4use serde_json::Value;
5use tracing::{debug, instrument};
6
7use hoist_core::resources::ResourceKind;
8use hoist_core::Config;
9
10use crate::auth::{get_auth_provider, AuthProvider};
11use crate::error::ClientError;
12
13pub struct AzureSearchClient {
15 http: Client,
16 auth: Box<dyn AuthProvider>,
17 base_url: String,
18 api_version: String,
19 preview_api_version: String,
20}
21
22impl AzureSearchClient {
23 pub fn new(config: &Config) -> Result<Self, ClientError> {
25 let auth = get_auth_provider()?;
26 let http = Client::builder()
27 .timeout(std::time::Duration::from_secs(30))
28 .build()?;
29
30 Ok(Self {
31 http,
32 auth,
33 base_url: config.service_url(),
34 api_version: config.service.api_version.clone(),
35 preview_api_version: config.service.preview_api_version.clone(),
36 })
37 }
38
39 pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
41 let auth = get_auth_provider()?;
42 let http = Client::builder()
43 .timeout(std::time::Duration::from_secs(30))
44 .build()?;
45
46 Ok(Self {
47 http,
48 auth,
49 base_url: format!("https://{}.search.windows.net", server_name),
50 api_version: config.service.api_version.clone(),
51 preview_api_version: config.service.preview_api_version.clone(),
52 })
53 }
54
55 pub fn with_auth(
57 base_url: String,
58 api_version: String,
59 preview_api_version: String,
60 auth: Box<dyn AuthProvider>,
61 ) -> Result<Self, ClientError> {
62 let http = Client::builder()
63 .timeout(std::time::Duration::from_secs(30))
64 .build()?;
65
66 Ok(Self {
67 http,
68 auth,
69 base_url,
70 api_version,
71 preview_api_version,
72 })
73 }
74
75 fn api_version_for(&self, kind: ResourceKind) -> &str {
77 if kind.is_preview() {
78 &self.preview_api_version
79 } else {
80 &self.api_version
81 }
82 }
83
84 fn collection_url(&self, kind: ResourceKind) -> String {
86 format!(
87 "{}/{}?api-version={}",
88 self.base_url,
89 kind.api_path(),
90 self.api_version_for(kind)
91 )
92 }
93
94 fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
96 format!(
97 "{}/{}/{}?api-version={}",
98 self.base_url,
99 kind.api_path(),
100 name,
101 self.api_version_for(kind)
102 )
103 }
104
105 async fn request(
107 &self,
108 method: Method,
109 url: &str,
110 body: Option<&Value>,
111 ) -> Result<Option<Value>, ClientError> {
112 let token = self.auth.get_token()?;
113
114 let mut request = self
115 .http
116 .request(method.clone(), url)
117 .header("Authorization", format!("Bearer {}", token))
118 .header("Content-Type", "application/json");
119
120 if let Some(json) = body {
121 request = request.json(json);
122 }
123
124 debug!("Request: {} {}", method, url);
125 let response = request.send().await?;
126 let status = response.status();
127
128 if status == StatusCode::NO_CONTENT {
129 return Ok(None);
130 }
131
132 let body = response.text().await?;
133
134 if status.is_success() {
135 if body.is_empty() {
136 Ok(None)
137 } else {
138 let value: Value = serde_json::from_str(&body)?;
139 Ok(Some(value))
140 }
141 } else {
142 match status {
143 StatusCode::NOT_FOUND => Err(ClientError::NotFound {
144 kind: "resource".to_string(),
145 name: url.to_string(),
146 }),
147 StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
148 kind: "resource".to_string(),
149 name: url.to_string(),
150 }),
151 StatusCode::TOO_MANY_REQUESTS => {
152 let retry_after = 60; Err(ClientError::RateLimited { retry_after })
154 }
155 StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
156 _ => Err(ClientError::from_response(status.as_u16(), &body)),
157 }
158 }
159 }
160
161 #[instrument(skip(self))]
163 pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
164 let url = self.collection_url(kind);
165 let response = self.request(Method::GET, &url, None).await?;
166
167 match response {
168 Some(value) => {
169 let items = value
171 .get("value")
172 .and_then(|v| v.as_array())
173 .cloned()
174 .unwrap_or_default();
175 Ok(items)
176 }
177 None => Ok(Vec::new()),
178 }
179 }
180
181 #[instrument(skip(self))]
183 pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
184 let url = self.resource_url(kind, name);
185 let response = self.request(Method::GET, &url, None).await?;
186
187 response.ok_or_else(|| ClientError::NotFound {
188 kind: kind.display_name().to_string(),
189 name: name.to_string(),
190 })
191 }
192
193 #[instrument(skip(self, definition))]
199 pub async fn create_or_update(
200 &self,
201 kind: ResourceKind,
202 name: &str,
203 definition: &Value,
204 ) -> Result<Option<Value>, ClientError> {
205 let url = self.resource_url(kind, name);
206 self.request(Method::PUT, &url, Some(definition)).await
207 }
208
209 #[instrument(skip(self))]
211 pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
212 let url = self.resource_url(kind, name);
213 self.request(Method::DELETE, &url, None).await?;
214 Ok(())
215 }
216
217 pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
219 match self.get(kind, name).await {
220 Ok(_) => Ok(true),
221 Err(ClientError::NotFound { .. }) => Ok(false),
222 Err(e) => Err(e),
223 }
224 }
225
226 pub fn auth_method(&self) -> &'static str {
228 self.auth.method_name()
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Auth provider stub that returns a constant token and performs no I/O.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Builds a client rooted at `base_url` with fixed stable and preview API
    /// versions and the fake auth provider.
    fn client_for(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_owned(),
            String::from("2024-07-01"),
            String::from("2025-11-01-preview"),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Client pointed at the default test service.
    fn make_client() -> AzureSearchClient {
        client_for("https://test-svc.search.windows.net")
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let actual = make_client().collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_stable() {
        let actual = make_client().resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_resource_url_preview() {
        let actual = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains("2024-07-01"),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // NOTE(review): this does not call `new_for_server` itself, presumably
        // because that constructor needs a real auth provider. It checks URL
        // construction for an alternate base URL via `with_auth` instead.
        let actual =
            client_for("https://other-svc.search.windows.net").collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://other-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            // Only preview kinds are relevant to this check.
            if !kind.is_preview() {
                continue;
            }
            let url = client.collection_url(*kind);
            assert!(
                url.contains("2025-11-01-preview"),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }
}