1use serde::de::DeserializeOwned;
2use tracing::instrument;
3
4use busbar_sf_client::security::{soql, url as url_security};
5
6use crate::error::{Error, ErrorKind, Result};
7
8impl super::SalesforceRestClient {
9 #[instrument(skip(self))]
11 pub async fn get_blob(&self, sobject: &str, id: &str, blob_field: &str) -> Result<Vec<u8>> {
12 if !soql::is_safe_sobject_name(sobject) {
13 return Err(Error::new(ErrorKind::Salesforce {
14 error_code: "INVALID_SOBJECT".to_string(),
15 message: "Invalid SObject name".to_string(),
16 }));
17 }
18 if !url_security::is_valid_salesforce_id(id) {
19 return Err(Error::new(ErrorKind::Salesforce {
20 error_code: "INVALID_ID".to_string(),
21 message: "Invalid Salesforce ID format".to_string(),
22 }));
23 }
24 if !soql::is_safe_field_name(blob_field) {
25 return Err(Error::new(ErrorKind::Salesforce {
26 error_code: "INVALID_FIELD".to_string(),
27 message: "Invalid field name".to_string(),
28 }));
29 }
30 let path = format!("sobjects/{}/{}/{}", sobject, id, blob_field);
31 let url = self.client.rest_url(&path);
32 let request = self.client.get(&url);
33 let response = self.client.execute(request).await?;
34 let bytes = response.bytes().await?;
35 Ok(bytes.to_vec())
36 }
37
38 #[instrument(skip(self))]
40 pub async fn get_rich_text_image(
41 &self,
42 sobject: &str,
43 id: &str,
44 field_name: &str,
45 content_reference_id: &str,
46 ) -> Result<Vec<u8>> {
47 if !soql::is_safe_sobject_name(sobject) {
48 return Err(Error::new(ErrorKind::Salesforce {
49 error_code: "INVALID_SOBJECT".to_string(),
50 message: "Invalid SObject name".to_string(),
51 }));
52 }
53 if !url_security::is_valid_salesforce_id(id) {
54 return Err(Error::new(ErrorKind::Salesforce {
55 error_code: "INVALID_ID".to_string(),
56 message: "Invalid Salesforce ID format".to_string(),
57 }));
58 }
59 if !soql::is_safe_field_name(field_name) {
60 return Err(Error::new(ErrorKind::Salesforce {
61 error_code: "INVALID_FIELD".to_string(),
62 message: "Invalid field name".to_string(),
63 }));
64 }
65 if !url_security::is_valid_salesforce_id(content_reference_id) {
66 return Err(Error::new(ErrorKind::Salesforce {
67 error_code: "INVALID_ID".to_string(),
68 message: "Invalid content reference ID format".to_string(),
69 }));
70 }
71 let path = format!(
72 "sobjects/{}/{}/richTextImageFields/{}/{}",
73 sobject, id, field_name, content_reference_id
74 );
75 let url = self.client.rest_url(&path);
76 let request = self.client.get(&url);
77 let response = self.client.execute(request).await?;
78 let bytes = response.bytes().await?;
79 Ok(bytes.to_vec())
80 }
81
82 #[instrument(skip(self))]
84 pub async fn get_relationship<T: DeserializeOwned>(
85 &self,
86 sobject: &str,
87 id: &str,
88 relationship_name: &str,
89 ) -> Result<T> {
90 if !soql::is_safe_sobject_name(sobject) {
91 return Err(Error::new(ErrorKind::Salesforce {
92 error_code: "INVALID_SOBJECT".to_string(),
93 message: "Invalid SObject name".to_string(),
94 }));
95 }
96 if !url_security::is_valid_salesforce_id(id) {
97 return Err(Error::new(ErrorKind::Salesforce {
98 error_code: "INVALID_ID".to_string(),
99 message: "Invalid Salesforce ID format".to_string(),
100 }));
101 }
102 if !soql::is_safe_field_name(relationship_name) {
103 return Err(Error::new(ErrorKind::Salesforce {
104 error_code: "INVALID_FIELD".to_string(),
105 message: "Invalid relationship name".to_string(),
106 }));
107 }
108 let path = format!("sobjects/{}/{}/{}", sobject, id, relationship_name);
109 self.client.rest_get(&path).await.map_err(Into::into)
110 }
111
112 #[instrument(skip(self))]
114 pub async fn get_sobject_basic_info(&self, sobject: &str) -> Result<super::SObjectInfo> {
115 if !soql::is_safe_sobject_name(sobject) {
116 return Err(Error::new(ErrorKind::Salesforce {
117 error_code: "INVALID_SOBJECT".to_string(),
118 message: "Invalid SObject name".to_string(),
119 }));
120 }
121 let path = format!("sobjects/{}", sobject);
122 self.client.rest_get(&path).await.map_err(Into::into)
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::super::SalesforceRestClient;

    /// Record ID used wherever a well-formed ID is required.
    const RECORD_ID: &str = "001xx000003Dgb2AAC";
    /// Input that the SObject/field validators are expected to reject.
    const UNSAFE_NAME: &str = "Bad'; DROP--";
    /// Content reference ID used by the rich-text image tests.
    const CONTENT_REF_ID: &str = "0P0xx000000001XABC";

    /// Builds a client for tests that must fail validation before any request is sent.
    fn offline_client() -> SalesforceRestClient {
        SalesforceRestClient::new("https://test.salesforce.com", "token").unwrap()
    }

    #[tokio::test]
    async fn test_get_blob_invalid_sobject() {
        let result = offline_client()
            .get_blob(UNSAFE_NAME, RECORD_ID, "Body")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_blob_invalid_id() {
        let result = offline_client().get_blob("Attachment", "bad-id", "Body").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_blob_invalid_field() {
        let result = offline_client()
            .get_blob("Attachment", RECORD_ID, UNSAFE_NAME)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_sobject() {
        let result = offline_client()
            .get_rich_text_image(UNSAFE_NAME, RECORD_ID, "RichText__c", CONTENT_REF_ID)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_id() {
        let result = offline_client()
            .get_rich_text_image("Account", "bad-id", "RichText__c", CONTENT_REF_ID)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_field() {
        let result = offline_client()
            .get_rich_text_image("Account", RECORD_ID, UNSAFE_NAME, CONTENT_REF_ID)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_content_reference_id() {
        let result = offline_client()
            .get_rich_text_image("Account", RECORD_ID, "RichText__c", "bad-id")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_sobject() {
        let result = offline_client()
            .get_relationship::<serde_json::Value>(UNSAFE_NAME, RECORD_ID, "Contacts")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_id() {
        let result = offline_client()
            .get_relationship::<serde_json::Value>("Account", "bad-id", "Contacts")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_name() {
        let result = offline_client()
            .get_relationship::<serde_json::Value>("Account", RECORD_ID, UNSAFE_NAME)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_sobject_basic_info_invalid_sobject() {
        let result = offline_client().get_sobject_basic_info(UNSAFE_NAME).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_blob_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let binary_content = vec![0x89, 0x50, 0x4E, 0x47];
        Mock::given(method("GET"))
            .and(path_regex(
                ".*/sobjects/Attachment/001xx000003Dgb2AAC/Body$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(binary_content.clone()))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .get_blob("Attachment", RECORD_ID, "Body")
            .await
            .expect("get_blob should succeed");
        assert_eq!(result, binary_content);
    }

    #[tokio::test]
    async fn test_get_relationship_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let body = serde_json::json!({
            "totalSize": 1,
            "done": true,
            "records": [{"Id": "003xx000001Svf0AAC", "Name": "John Doe"}]
        });

        Mock::given(method("GET"))
            .and(path_regex(
                ".*/sobjects/Account/001xx000003Dgb2AAC/Contacts$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(&body))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result: serde_json::Value = client
            .get_relationship("Account", RECORD_ID, "Contacts")
            .await
            .expect("get_relationship should succeed");
        assert_eq!(result["totalSize"], 1);
    }

    #[tokio::test]
    async fn test_get_sobject_basic_info_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let body = serde_json::json!({
            "objectDescribe": {
                "name": "Account",
                "label": "Account",
                "keyPrefix": "001",
                "urls": {"sobject": "/services/data/v62.0/sobjects/Account"},
                "custom": false,
                "createable": true,
                "updateable": true,
                "deletable": true,
                "queryable": true,
                "searchable": true
            },
            "recentItems": []
        });

        Mock::given(method("GET"))
            .and(path_regex(".*/sobjects/Account$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&body))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .get_sobject_basic_info("Account")
            .await
            .expect("get_sobject_basic_info should succeed");
        assert_eq!(result.object_describe.name, "Account");
        assert!(result.object_describe.createable);
    }
}