Skip to main content

busbar_sf_rest/client/
binary.rs

1use serde::de::DeserializeOwned;
2use tracing::instrument;
3
4use busbar_sf_client::security::{soql, url as url_security};
5
6use crate::error::{Error, ErrorKind, Result};
7
8impl super::SalesforceRestClient {
9    /// Get binary blob content from an SObject field (e.g., Attachment body, Document body).
10    #[instrument(skip(self))]
11    pub async fn get_blob(&self, sobject: &str, id: &str, blob_field: &str) -> Result<Vec<u8>> {
12        if !soql::is_safe_sobject_name(sobject) {
13            return Err(Error::new(ErrorKind::Salesforce {
14                error_code: "INVALID_SOBJECT".to_string(),
15                message: "Invalid SObject name".to_string(),
16            }));
17        }
18        if !url_security::is_valid_salesforce_id(id) {
19            return Err(Error::new(ErrorKind::Salesforce {
20                error_code: "INVALID_ID".to_string(),
21                message: "Invalid Salesforce ID format".to_string(),
22            }));
23        }
24        if !soql::is_safe_field_name(blob_field) {
25            return Err(Error::new(ErrorKind::Salesforce {
26                error_code: "INVALID_FIELD".to_string(),
27                message: "Invalid field name".to_string(),
28            }));
29        }
30        let path = format!("sobjects/{}/{}/{}", sobject, id, blob_field);
31        let url = self.client.rest_url(&path);
32        let request = self.client.get(&url);
33        let response = self.client.execute(request).await?;
34        let bytes = response.bytes().await?;
35        Ok(bytes.to_vec())
36    }
37
38    /// Get a rich text image from an SObject field.
39    #[instrument(skip(self))]
40    pub async fn get_rich_text_image(
41        &self,
42        sobject: &str,
43        id: &str,
44        field_name: &str,
45        content_reference_id: &str,
46    ) -> Result<Vec<u8>> {
47        if !soql::is_safe_sobject_name(sobject) {
48            return Err(Error::new(ErrorKind::Salesforce {
49                error_code: "INVALID_SOBJECT".to_string(),
50                message: "Invalid SObject name".to_string(),
51            }));
52        }
53        if !url_security::is_valid_salesforce_id(id) {
54            return Err(Error::new(ErrorKind::Salesforce {
55                error_code: "INVALID_ID".to_string(),
56                message: "Invalid Salesforce ID format".to_string(),
57            }));
58        }
59        if !soql::is_safe_field_name(field_name) {
60            return Err(Error::new(ErrorKind::Salesforce {
61                error_code: "INVALID_FIELD".to_string(),
62                message: "Invalid field name".to_string(),
63            }));
64        }
65        if !url_security::is_valid_salesforce_id(content_reference_id) {
66            return Err(Error::new(ErrorKind::Salesforce {
67                error_code: "INVALID_ID".to_string(),
68                message: "Invalid content reference ID format".to_string(),
69            }));
70        }
71        let path = format!(
72            "sobjects/{}/{}/richTextImageFields/{}/{}",
73            sobject, id, field_name, content_reference_id
74        );
75        let url = self.client.rest_url(&path);
76        let request = self.client.get(&url);
77        let response = self.client.execute(request).await?;
78        let bytes = response.bytes().await?;
79        Ok(bytes.to_vec())
80    }
81
82    /// Get related records via a relationship field.
83    #[instrument(skip(self))]
84    pub async fn get_relationship<T: DeserializeOwned>(
85        &self,
86        sobject: &str,
87        id: &str,
88        relationship_name: &str,
89    ) -> Result<T> {
90        if !soql::is_safe_sobject_name(sobject) {
91            return Err(Error::new(ErrorKind::Salesforce {
92                error_code: "INVALID_SOBJECT".to_string(),
93                message: "Invalid SObject name".to_string(),
94            }));
95        }
96        if !url_security::is_valid_salesforce_id(id) {
97            return Err(Error::new(ErrorKind::Salesforce {
98                error_code: "INVALID_ID".to_string(),
99                message: "Invalid Salesforce ID format".to_string(),
100            }));
101        }
102        if !soql::is_safe_field_name(relationship_name) {
103            return Err(Error::new(ErrorKind::Salesforce {
104                error_code: "INVALID_FIELD".to_string(),
105                message: "Invalid relationship name".to_string(),
106            }));
107        }
108        let path = format!("sobjects/{}/{}/{}", sobject, id, relationship_name);
109        self.client.rest_get(&path).await.map_err(Into::into)
110    }
111
112    /// Get basic info about an SObject type (describe + recent items).
113    #[instrument(skip(self))]
114    pub async fn get_sobject_basic_info(&self, sobject: &str) -> Result<super::SObjectInfo> {
115        if !soql::is_safe_sobject_name(sobject) {
116            return Err(Error::new(ErrorKind::Salesforce {
117                error_code: "INVALID_SOBJECT".to_string(),
118                message: "Invalid SObject name".to_string(),
119            }));
120        }
121        let path = format!("sobjects/{}", sobject);
122        self.client.rest_get(&path).await.map_err(Into::into)
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::super::SalesforceRestClient;

    /// Client for the validation tests. Every call made with it in this module is rejected
    /// before a request is built, so the host is never contacted.
    fn offline_client() -> SalesforceRestClient {
        SalesforceRestClient::new("https://test.salesforce.com", "token").unwrap()
    }

    #[tokio::test]
    async fn test_get_blob_invalid_sobject() {
        let client = offline_client();
        let result = client
            .get_blob("Bad'; DROP--", "001xx000003Dgb2AAC", "Body")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_blob_invalid_id() {
        let client = offline_client();
        let result = client.get_blob("Attachment", "bad-id", "Body").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_blob_invalid_field() {
        let client = offline_client();
        let result = client
            .get_blob("Attachment", "001xx000003Dgb2AAC", "Bad'; DROP--")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_sobject() {
        let client = offline_client();
        let result = client
            .get_rich_text_image(
                "Bad'; DROP--",
                "001xx000003Dgb2AAC",
                "RichText__c",
                "0P0xx000000001XABC",
            )
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_id() {
        let client = offline_client();
        let result = client
            .get_rich_text_image("Account", "bad-id", "RichText__c", "0EMxx0000000001GAA")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_field() {
        let client = offline_client();
        let result = client
            .get_rich_text_image(
                "Account",
                "001xx000003Dgb2AAC",
                "Bad'; DROP--",
                "0EMxx0000000001GAA",
            )
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_rich_text_image_invalid_content_reference_id() {
        let client = offline_client();
        let result = client
            .get_rich_text_image("Account", "001xx000003Dgb2AAC", "RichText__c", "bad-id")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_sobject() {
        let client = offline_client();
        let result = client
            .get_relationship::<serde_json::Value>("Bad'; DROP--", "001xx000003Dgb2AAC", "Contacts")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_id() {
        let client = offline_client();
        let result = client
            .get_relationship::<serde_json::Value>("Account", "bad-id", "Contacts")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_get_relationship_invalid_name() {
        let client = offline_client();
        let result = client
            .get_relationship::<serde_json::Value>("Account", "001xx000003Dgb2AAC", "Bad'; DROP--")
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_get_sobject_basic_info_invalid_sobject() {
        let client = offline_client();
        let result = client.get_sobject_basic_info("Bad'; DROP--").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_blob_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let binary_content = vec![0x89, 0x50, 0x4E, 0x47]; // PNG magic bytes

        Mock::given(method("GET"))
            .and(path_regex(
                ".*/sobjects/Attachment/001xx000003Dgb2AAC/Body$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(binary_content.clone()))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .get_blob("Attachment", "001xx000003Dgb2AAC", "Body")
            .await
            .expect("get_blob should succeed");
        assert_eq!(result, binary_content);
    }

    #[tokio::test]
    async fn test_get_rich_text_image_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let image_bytes = vec![0xFF, 0xD8, 0xFF, 0xE0]; // JPEG magic bytes

        // 18-character content reference ID whose case-checksum suffix ("GAA") matches its
        // 15-character prefix.
        Mock::given(method("GET"))
            .and(path_regex(
                ".*/sobjects/Account/001xx000003Dgb2AAC/richTextImageFields/RichText__c/0EMxx0000000001GAA$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_bytes(image_bytes.clone()))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .get_rich_text_image(
                "Account",
                "001xx000003Dgb2AAC",
                "RichText__c",
                "0EMxx0000000001GAA",
            )
            .await
            .expect("get_rich_text_image should succeed");
        assert_eq!(result, image_bytes);
    }

    #[tokio::test]
    async fn test_get_relationship_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let body = serde_json::json!({
            "totalSize": 1,
            "done": true,
            "records": [{"Id": "003xx000001Svf0AAC", "Name": "John Doe"}]
        });

        Mock::given(method("GET"))
            .and(path_regex(
                ".*/sobjects/Account/001xx000003Dgb2AAC/Contacts$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(&body))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result: serde_json::Value = client
            .get_relationship("Account", "001xx000003Dgb2AAC", "Contacts")
            .await
            .expect("get_relationship should succeed");
        assert_eq!(result["totalSize"], 1);
    }

    #[tokio::test]
    async fn test_get_sobject_basic_info_wiremock() {
        use wiremock::matchers::{method, path_regex};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let mock_server = MockServer::start().await;

        let body = serde_json::json!({
            "objectDescribe": {
                "name": "Account",
                "label": "Account",
                "keyPrefix": "001",
                "urls": {"sobject": "/services/data/v62.0/sobjects/Account"},
                "custom": false,
                "createable": true,
                "updateable": true,
                "deletable": true,
                "queryable": true,
                "searchable": true
            },
            "recentItems": []
        });

        Mock::given(method("GET"))
            .and(path_regex(".*/sobjects/Account$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&body))
            .mount(&mock_server)
            .await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .get_sobject_basic_info("Account")
            .await
            .expect("get_sobject_basic_info should succeed");
        assert_eq!(result.object_describe.name, "Account");
        assert!(result.object_describe.createable);
    }
}