// agixt_sdk/client/providers.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

use crate::error::Result;

5#[derive(Debug, Serialize, Deserialize)]
6pub struct ProviderResponse {
7    pub providers: Vec<String>,
8}
9
10#[derive(Debug, Serialize, Deserialize)]
11pub struct ProviderSettings {
12    pub settings: HashMap<String, serde_json::Value>,
13}
14
15impl super::AGiXTSDK {
16    /// Get list of available providers
17    pub async fn get_providers(&self) -> Result<Vec<String>> {
18        let response = self
19            .client
20            .get(&format!("{}/api/provider", self.base_uri))
21            .headers(self.headers.lock().await.clone())
22            .send()
23            .await?;
24
25        let status = response.status();
26        let text = response.text().await?;
27
28        if self.verbose {
29            self.parse_response(status, &text).await?;
30        }
31
32        let result: ProviderResponse = serde_json::from_str(&text)?;
33        Ok(result.providers)
34    }
35
36    /// Get providers by service type
37    pub async fn get_providers_by_service(&self, service: &str) -> Result<Vec<String>> {
38        let response = self
39            .client
40            .get(&format!("{}/api/providers/service/{}", self.base_uri, service))
41            .headers(self.headers.lock().await.clone())
42            .send()
43            .await?;
44
45        let status = response.status();
46        let text = response.text().await?;
47
48        if self.verbose {
49            self.parse_response(status, &text).await?;
50        }
51
52        let result: ProviderResponse = serde_json::from_str(&text)?;
53        Ok(result.providers)
54    }
55
56    /// Get settings for a specific provider
57    pub async fn get_provider_settings(&self, provider_name: &str) -> Result<HashMap<String, serde_json::Value>> {
58        let response = self
59            .client
60            .get(&format!("{}/api/provider/{}", self.base_uri, provider_name))
61            .headers(self.headers.lock().await.clone())
62            .send()
63            .await?;
64
65        let status = response.status();
66        let text = response.text().await?;
67
68        if self.verbose {
69            self.parse_response(status, &text).await?;
70        }
71
72        let result: ProviderSettings = serde_json::from_str(&text)?;
73        Ok(result.settings)
74    }
75
76    /// Get list of embedding providers
77    pub async fn get_embed_providers(&self) -> Result<Vec<String>> {
78        let response = self
79            .client
80            .get(&format!("{}/api/embedding_providers", self.base_uri))
81            .headers(self.headers.lock().await.clone())
82            .send()
83            .await?;
84
85        let status = response.status();
86        let text = response.text().await?;
87
88        if self.verbose {
89            self.parse_response(status, &text).await?;
90        }
91
92        let result: ProviderResponse = serde_json::from_str(&text)?;
93        Ok(result.providers)
94    }
95
96    /// Get details of all embedders
97    pub async fn get_embedders(&self) -> Result<HashMap<String, serde_json::Value>> {
98        let response = self
99            .client
100            .get(&format!("{}/api/embedders", self.base_uri))
101            .headers(self.headers.lock().await.clone())
102            .send()
103            .await?;
104
105        let status = response.status();
106        let text = response.text().await?;
107
108        if self.verbose {
109            self.parse_response(status, &text).await?;
110        }
111
112        #[derive(Deserialize)]
113        struct EmbeddersResponse {
114            embedders: HashMap<String, serde_json::Value>,
115        }
116
117        let result: EmbeddersResponse = serde_json::from_str(&text)?;
118        Ok(result.embedders)
119    }
120}
121
#[cfg(test)]
mod tests {
    use crate::AGiXTSDK;

    /// Register a `GET {path}` mock on `server` that answers 200 with the JSON `body`.
    ///
    /// These tests run inside a Tokio runtime, so they use mockito's `*_async`
    /// API rather than the blocking `Server::new` / `Mock::create`.
    async fn mock_json_get(
        server: &mut mockito::ServerGuard,
        path: &str,
        body: &str,
    ) -> mockito::Mock {
        server
            .mock("GET", path)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await
    }

    #[tokio::test]
    async fn test_get_providers() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_json_get(
            &mut server,
            "/api/provider",
            r#"{"providers": ["provider1", "provider2"]}"#,
        )
        .await;

        let client = AGiXTSDK::new(Some(server.url()), None, false);
        let providers = client.get_providers().await.unwrap();

        assert_eq!(providers, vec!["provider1", "provider2"]);
        // Verifies the client actually hit the expected endpoint.
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_providers_by_service() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_json_get(
            &mut server,
            "/api/providers/service/llm",
            r#"{"providers": ["openai"]}"#,
        )
        .await;

        let client = AGiXTSDK::new(Some(server.url()), None, false);
        let providers = client.get_providers_by_service("llm").await.unwrap();

        assert_eq!(providers, vec!["openai"]);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_provider_settings() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_json_get(
            &mut server,
            "/api/provider/openai",
            r#"{"settings": {"MAX_TOKENS": 4096}}"#,
        )
        .await;

        let client = AGiXTSDK::new(Some(server.url()), None, false);
        let settings = client.get_provider_settings("openai").await.unwrap();

        assert_eq!(settings.get("MAX_TOKENS"), Some(&serde_json::json!(4096)));
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_embed_providers() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_json_get(
            &mut server,
            "/api/embedding_providers",
            r#"{"providers": ["default"]}"#,
        )
        .await;

        let client = AGiXTSDK::new(Some(server.url()), None, false);
        let providers = client.get_embed_providers().await.unwrap();

        assert_eq!(providers, vec!["default"]);
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_get_embedders() {
        let mut server = mockito::Server::new_async().await;
        let mock = mock_json_get(
            &mut server,
            "/api/embedders",
            r#"{"embedders": {"default": {"chunk_size": 256}}}"#,
        )
        .await;

        let client = AGiXTSDK::new(Some(server.url()), None, false);
        let embedders = client.get_embedders().await.unwrap();

        assert_eq!(
            embedders.get("default"),
            Some(&serde_json::json!({"chunk_size": 256}))
        );
        mock.assert_async().await;
    }
}
142}