Skip to main content

hoist_client/
client.rs

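//! Azure Search REST API client.
//!
//! Provides [`AzureSearchClient`], an async wrapper over the service's
//! resource-management endpoints (list / get / create-or-update / delete).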
2
3use reqwest::{Client, Method, StatusCode};
4use serde_json::Value;
5use tracing::{debug, instrument};
6
7use hoist_core::resources::ResourceKind;
8use hoist_core::Config;
9
10use crate::auth::{get_auth_provider, AuthProvider};
11use crate::error::ClientError;
12
13/// Azure Search API client
14pub struct AzureSearchClient {
15    http: Client,
16    auth: Box<dyn AuthProvider>,
17    base_url: String,
18    api_version: String,
19    preview_api_version: String,
20}
21
22impl AzureSearchClient {
23    /// Create a new client from configuration
24    pub fn new(config: &Config) -> Result<Self, ClientError> {
25        let auth = get_auth_provider()?;
26        let http = Client::builder()
27            .timeout(std::time::Duration::from_secs(30))
28            .build()?;
29
30        Ok(Self {
31            http,
32            auth,
33            base_url: config.service_url(),
34            api_version: config.service.api_version.clone(),
35            preview_api_version: config.service.preview_api_version.clone(),
36        })
37    }
38
39    /// Create a client pointing to a different server, using the same auth and API versions
40    pub fn new_for_server(config: &Config, server_name: &str) -> Result<Self, ClientError> {
41        let auth = get_auth_provider()?;
42        let http = Client::builder()
43            .timeout(std::time::Duration::from_secs(30))
44            .build()?;
45
46        Ok(Self {
47            http,
48            auth,
49            base_url: format!("https://{}.search.windows.net", server_name),
50            api_version: config.service.api_version.clone(),
51            preview_api_version: config.service.preview_api_version.clone(),
52        })
53    }
54
55    /// Create with a custom auth provider (for testing)
56    pub fn with_auth(
57        base_url: String,
58        api_version: String,
59        preview_api_version: String,
60        auth: Box<dyn AuthProvider>,
61    ) -> Result<Self, ClientError> {
62        let http = Client::builder()
63            .timeout(std::time::Duration::from_secs(30))
64            .build()?;
65
66        Ok(Self {
67            http,
68            auth,
69            base_url,
70            api_version,
71            preview_api_version,
72        })
73    }
74
75    /// Get the API version to use for a resource kind
76    fn api_version_for(&self, kind: ResourceKind) -> &str {
77        if kind.is_preview() {
78            &self.preview_api_version
79        } else {
80            &self.api_version
81        }
82    }
83
84    /// Build URL for a resource collection
85    fn collection_url(&self, kind: ResourceKind) -> String {
86        format!(
87            "{}/{}?api-version={}",
88            self.base_url,
89            kind.api_path(),
90            self.api_version_for(kind)
91        )
92    }
93
94    /// Build URL for a specific resource
95    fn resource_url(&self, kind: ResourceKind, name: &str) -> String {
96        format!(
97            "{}/{}/{}?api-version={}",
98            self.base_url,
99            kind.api_path(),
100            name,
101            self.api_version_for(kind)
102        )
103    }
104
105    /// Execute an HTTP request
106    async fn request(
107        &self,
108        method: Method,
109        url: &str,
110        body: Option<&Value>,
111    ) -> Result<Option<Value>, ClientError> {
112        let token = self.auth.get_token()?;
113
114        let mut request = self
115            .http
116            .request(method.clone(), url)
117            .header("Authorization", format!("Bearer {}", token))
118            .header("Content-Type", "application/json");
119
120        if let Some(json) = body {
121            request = request.json(json);
122        }
123
124        debug!("Request: {} {}", method, url);
125        let response = request.send().await?;
126        let status = response.status();
127
128        if status == StatusCode::NO_CONTENT {
129            return Ok(None);
130        }
131
132        let body = response.text().await?;
133
134        if status.is_success() {
135            if body.is_empty() {
136                Ok(None)
137            } else {
138                let value: Value = serde_json::from_str(&body)?;
139                Ok(Some(value))
140            }
141        } else {
142            match status {
143                StatusCode::NOT_FOUND => Err(ClientError::NotFound {
144                    kind: "resource".to_string(),
145                    name: url.to_string(),
146                }),
147                StatusCode::CONFLICT => Err(ClientError::AlreadyExists {
148                    kind: "resource".to_string(),
149                    name: url.to_string(),
150                }),
151                StatusCode::TOO_MANY_REQUESTS => {
152                    let retry_after = 60; // Default retry time
153                    Err(ClientError::RateLimited { retry_after })
154                }
155                StatusCode::SERVICE_UNAVAILABLE => Err(ClientError::ServiceUnavailable(body)),
156                _ => Err(ClientError::from_response(status.as_u16(), &body)),
157            }
158        }
159    }
160
161    /// List all resources of a given kind
162    #[instrument(skip(self))]
163    pub async fn list(&self, kind: ResourceKind) -> Result<Vec<Value>, ClientError> {
164        let url = self.collection_url(kind);
165        let response = self.request(Method::GET, &url, None).await?;
166
167        match response {
168            Some(value) => {
169                // Azure returns { "value": [...] }
170                let items = value
171                    .get("value")
172                    .and_then(|v| v.as_array())
173                    .cloned()
174                    .unwrap_or_default();
175                Ok(items)
176            }
177            None => Ok(Vec::new()),
178        }
179    }
180
181    /// Get a specific resource
182    #[instrument(skip(self))]
183    pub async fn get(&self, kind: ResourceKind, name: &str) -> Result<Value, ClientError> {
184        let url = self.resource_url(kind, name);
185        let response = self.request(Method::GET, &url, None).await?;
186
187        response.ok_or_else(|| ClientError::NotFound {
188            kind: kind.display_name().to_string(),
189            name: name.to_string(),
190        })
191    }
192
193    /// Create or update a resource
194    ///
195    /// Returns the response body if the API returns one. Some APIs (especially
196    /// preview endpoints like Knowledge Sources) return 204 No Content on
197    /// successful update, which yields `Ok(None)`.
198    #[instrument(skip(self, definition))]
199    pub async fn create_or_update(
200        &self,
201        kind: ResourceKind,
202        name: &str,
203        definition: &Value,
204    ) -> Result<Option<Value>, ClientError> {
205        let url = self.resource_url(kind, name);
206        self.request(Method::PUT, &url, Some(definition)).await
207    }
208
209    /// Delete a resource
210    #[instrument(skip(self))]
211    pub async fn delete(&self, kind: ResourceKind, name: &str) -> Result<(), ClientError> {
212        let url = self.resource_url(kind, name);
213        self.request(Method::DELETE, &url, None).await?;
214        Ok(())
215    }
216
217    /// Check if a resource exists
218    pub async fn exists(&self, kind: ResourceKind, name: &str) -> Result<bool, ClientError> {
219        match self.get(kind, name).await {
220            Ok(_) => Ok(true),
221            Err(ClientError::NotFound { .. }) => Ok(false),
222            Err(e) => Err(e),
223        }
224    }
225
226    /// Get the authentication method being used
227    pub fn auth_method(&self) -> &'static str {
228        self.auth.method_name()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthError, AuthProvider};

    /// Stable API version handed to every test client.
    const STABLE_VERSION: &str = "2024-07-01";
    /// Preview API version handed to every test client.
    const PREVIEW_VERSION: &str = "2025-11-01-preview";

    /// Auth provider that always succeeds with a fixed token.
    struct FakeAuth;

    impl AuthProvider for FakeAuth {
        fn get_token(&self) -> Result<String, AuthError> {
            Ok(String::from("fake-token"))
        }

        fn method_name(&self) -> &'static str {
            "Fake"
        }
    }

    /// Build a client rooted at `base_url` with the fixed test API versions.
    fn client_at(base_url: &str) -> AzureSearchClient {
        AzureSearchClient::with_auth(
            base_url.to_owned(),
            STABLE_VERSION.to_owned(),
            PREVIEW_VERSION.to_owned(),
            Box::new(FakeAuth),
        )
        .unwrap()
    }

    /// Build the default test client.
    fn make_client() -> AzureSearchClient {
        client_at("https://test-svc.search.windows.net")
    }

    #[test]
    fn test_collection_url_stable_resource() {
        let actual = make_client().collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_collection_url_preview_resource_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeBase);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_collection_url_knowledge_source_uses_preview_version() {
        let actual = make_client().collection_url(ResourceKind::KnowledgeSource);
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgesources?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_resource_url_stable() {
        let actual = make_client().resource_url(ResourceKind::Index, "my-index");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/indexes/my-index?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_resource_url_preview() {
        let actual = make_client().resource_url(ResourceKind::KnowledgeBase, "my-kb");
        assert_eq!(
            actual,
            "https://test-svc.search.windows.net/knowledgebases/my-kb?api-version=2025-11-01-preview"
        );
    }

    #[test]
    fn test_all_stable_kinds_use_stable_version() {
        let client = make_client();
        for kind in ResourceKind::stable() {
            let url = client.collection_url(*kind);
            assert!(
                url.contains("2024-07-01"),
                "{:?} should use stable API version, got: {}",
                kind,
                url
            );
        }
    }

    #[test]
    fn test_new_for_server_produces_correct_base_url() {
        // `new_for_server` itself calls `get_auth_provider`, so it is not
        // exercised directly; this checks the URL shape through `with_auth`.
        let client = client_at("https://other-svc.search.windows.net");
        let actual = client.collection_url(ResourceKind::Index);
        assert_eq!(
            actual,
            "https://other-svc.search.windows.net/indexes?api-version=2024-07-01"
        );
    }

    #[test]
    fn test_all_preview_kinds_use_preview_version() {
        let client = make_client();
        for kind in ResourceKind::all() {
            // Only preview kinds are of interest here.
            if !kind.is_preview() {
                continue;
            }
            let url = client.collection_url(*kind);
            assert!(
                url.contains("2025-11-01-preview"),
                "{:?} should use preview API version, got: {}",
                kind,
                url
            );
        }
    }
}
354}