// agixt_sdk/client/providers.rs
use crate::error::Result;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;

5#[derive(Debug, Serialize, Deserialize)]
6pub struct ProviderResponse {
7 pub providers: Vec<String>,
8}
9
10#[derive(Debug, Serialize, Deserialize)]
11pub struct ProviderSettings {
12 pub settings: HashMap<String, serde_json::Value>,
13}
14
15impl super::AGiXTSDK {
16 pub async fn get_providers(&self) -> Result<Vec<String>> {
18 let response = self
19 .client
20 .get(&format!("{}/api/provider", self.base_uri))
21 .headers(self.headers.lock().await.clone())
22 .send()
23 .await?;
24
25 let status = response.status();
26 let text = response.text().await?;
27
28 if self.verbose {
29 self.parse_response(status, &text).await?;
30 }
31
32 let result: ProviderResponse = serde_json::from_str(&text)?;
33 Ok(result.providers)
34 }
35
36 pub async fn get_providers_by_service(&self, service: &str) -> Result<Vec<String>> {
38 let response = self
39 .client
40 .get(&format!("{}/api/providers/service/{}", self.base_uri, service))
41 .headers(self.headers.lock().await.clone())
42 .send()
43 .await?;
44
45 let status = response.status();
46 let text = response.text().await?;
47
48 if self.verbose {
49 self.parse_response(status, &text).await?;
50 }
51
52 let result: ProviderResponse = serde_json::from_str(&text)?;
53 Ok(result.providers)
54 }
55
56 pub async fn get_provider_settings(&self, provider_name: &str) -> Result<HashMap<String, serde_json::Value>> {
58 let response = self
59 .client
60 .get(&format!("{}/api/provider/{}", self.base_uri, provider_name))
61 .headers(self.headers.lock().await.clone())
62 .send()
63 .await?;
64
65 let status = response.status();
66 let text = response.text().await?;
67
68 if self.verbose {
69 self.parse_response(status, &text).await?;
70 }
71
72 let result: ProviderSettings = serde_json::from_str(&text)?;
73 Ok(result.settings)
74 }
75
76 pub async fn get_embed_providers(&self) -> Result<Vec<String>> {
78 let response = self
79 .client
80 .get(&format!("{}/api/embedding_providers", self.base_uri))
81 .headers(self.headers.lock().await.clone())
82 .send()
83 .await?;
84
85 let status = response.status();
86 let text = response.text().await?;
87
88 if self.verbose {
89 self.parse_response(status, &text).await?;
90 }
91
92 let result: ProviderResponse = serde_json::from_str(&text)?;
93 Ok(result.providers)
94 }
95
96 pub async fn get_embedders(&self) -> Result<HashMap<String, serde_json::Value>> {
98 let response = self
99 .client
100 .get(&format!("{}/api/embedders", self.base_uri))
101 .headers(self.headers.lock().await.clone())
102 .send()
103 .await?;
104
105 let status = response.status();
106 let text = response.text().await?;
107
108 if self.verbose {
109 self.parse_response(status, &text).await?;
110 }
111
112 #[derive(Deserialize)]
113 struct EmbeddersResponse {
114 embedders: HashMap<String, serde_json::Value>,
115 }
116
117 let result: EmbeddersResponse = serde_json::from_str(&text)?;
118 Ok(result.embedders)
119 }
120}
121
#[cfg(test)]
mod tests {
    use crate::AGiXTSDK;

    // These tests run inside the tokio runtime created by `#[tokio::test]`, so they
    // use mockito's async API (`new_async`, `create_async`, `assert_async`) rather
    // than its blocking counterparts.

    /// `get_providers` requests `GET /api/provider` and unwraps the `providers` list.
    #[tokio::test]
    async fn test_get_providers() {
        let mut mock_server = mockito::Server::new_async().await;
        let mock = mock_server
            .mock("GET", "/api/provider")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"providers": ["provider1", "provider2"]}"#)
            .create_async()
            .await;

        let client = AGiXTSDK::new(Some(mock_server.url()), None, false);
        let providers = client.get_providers().await.unwrap();

        assert_eq!(providers, vec!["provider1", "provider2"]);
        // Verify the endpoint was actually requested exactly once.
        mock.assert_async().await;
    }

    /// `get_providers_by_service` puts the service name in the request path.
    #[tokio::test]
    async fn test_get_providers_by_service() {
        let mut mock_server = mockito::Server::new_async().await;
        let mock = mock_server
            .mock("GET", "/api/providers/service/llm")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"providers": ["openai"]}"#)
            .create_async()
            .await;

        let client = AGiXTSDK::new(Some(mock_server.url()), None, false);
        let providers = client.get_providers_by_service("llm").await.unwrap();

        assert_eq!(providers, vec!["openai"]);
        mock.assert_async().await;
    }

    /// `get_provider_settings` unwraps the `settings` object into a map.
    #[tokio::test]
    async fn test_get_provider_settings() {
        let mut mock_server = mockito::Server::new_async().await;
        let mock = mock_server
            .mock("GET", "/api/provider/openai")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"settings": {"API_KEY": "", "MAX_TOKENS": 4096}}"#)
            .create_async()
            .await;

        let client = AGiXTSDK::new(Some(mock_server.url()), None, false);
        let settings = client.get_provider_settings("openai").await.unwrap();

        assert_eq!(settings.len(), 2);
        assert_eq!(settings.get("MAX_TOKENS"), Some(&serde_json::json!(4096)));
        assert_eq!(settings.get("API_KEY"), Some(&serde_json::json!("")));
        mock.assert_async().await;
    }

    /// `get_embedders` unwraps the `embedders` object into a map.
    #[tokio::test]
    async fn test_get_embedders() {
        let mut mock_server = mockito::Server::new_async().await;
        let mock = mock_server
            .mock("GET", "/api/embedders")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(r#"{"embedders": {"default": {"chunk_size": 256}}}"#)
            .create_async()
            .await;

        let client = AGiXTSDK::new(Some(mock_server.url()), None, false);
        let embedders = client.get_embedders().await.unwrap();

        assert_eq!(embedders.len(), 1);
        assert_eq!(embedders["default"]["chunk_size"], serde_json::json!(256));
        mock.assert_async().await;
    }
}