Skip to main content

busbar_sf_rest/client/
crud.rs

1use serde::{de::DeserializeOwned, Serialize};
2use tracing::instrument;
3
4use busbar_sf_client::security::{soql, url as url_security};
5
6use crate::error::{Error, ErrorKind, Result};
7use crate::sobject::{CreateResult, UpsertResult};
8
9impl super::SalesforceRestClient {
10    /// Create a new record.
11    ///
12    /// Returns the ID of the created record.
13    #[instrument(skip(self, record))]
14    pub async fn create<T: Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
15        if !soql::is_safe_sobject_name(sobject) {
16            return Err(Error::new(ErrorKind::Salesforce {
17                error_code: "INVALID_SOBJECT".to_string(),
18                message: "Invalid SObject name".to_string(),
19            }));
20        }
21        let path = format!("sobjects/{}", sobject);
22        let result: CreateResult = self.client.rest_post(&path, record).await?;
23
24        if result.success {
25            Ok(result.id)
26        } else {
27            let errors: Vec<String> = result.errors.iter().map(|e| e.message.clone()).collect();
28            Err(Error::new(ErrorKind::Salesforce {
29                error_code: "CREATE_FAILED".to_string(),
30                message: errors.join("; "),
31            }))
32        }
33    }
34
35    /// Get a record by ID.
36    ///
37    /// Optionally specify which fields to retrieve.
38    #[instrument(skip(self))]
39    pub async fn get<T: DeserializeOwned>(
40        &self,
41        sobject: &str,
42        id: &str,
43        fields: Option<&[&str]>,
44    ) -> Result<T> {
45        if !soql::is_safe_sobject_name(sobject) {
46            return Err(Error::new(ErrorKind::Salesforce {
47                error_code: "INVALID_SOBJECT".to_string(),
48                message: "Invalid SObject name".to_string(),
49            }));
50        }
51        if !url_security::is_valid_salesforce_id(id) {
52            return Err(Error::new(ErrorKind::Salesforce {
53                error_code: "INVALID_ID".to_string(),
54                message: "Invalid Salesforce ID format".to_string(),
55            }));
56        }
57        let path = if let Some(fields) = fields {
58            // Validate and filter field names for safety
59            let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
60            if safe_fields.is_empty() {
61                return Err(Error::new(ErrorKind::Salesforce {
62                    error_code: "INVALID_FIELDS".to_string(),
63                    message: "No valid field names provided".to_string(),
64                }));
65            }
66            format!(
67                "sobjects/{}/{}?fields={}",
68                sobject,
69                id,
70                safe_fields.join(",")
71            )
72        } else {
73            format!("sobjects/{}/{}", sobject, id)
74        };
75        self.client.rest_get(&path).await.map_err(Into::into)
76    }
77
78    /// Update a record.
79    #[instrument(skip(self, record))]
80    pub async fn update<T: Serialize>(&self, sobject: &str, id: &str, record: &T) -> Result<()> {
81        if !soql::is_safe_sobject_name(sobject) {
82            return Err(Error::new(ErrorKind::Salesforce {
83                error_code: "INVALID_SOBJECT".to_string(),
84                message: "Invalid SObject name".to_string(),
85            }));
86        }
87        if !url_security::is_valid_salesforce_id(id) {
88            return Err(Error::new(ErrorKind::Salesforce {
89                error_code: "INVALID_ID".to_string(),
90                message: "Invalid Salesforce ID format".to_string(),
91            }));
92        }
93        let path = format!("sobjects/{}/{}", sobject, id);
94        self.client
95            .rest_patch(&path, record)
96            .await
97            .map_err(Into::into)
98    }
99
100    /// Delete a record.
101    #[instrument(skip(self))]
102    pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
103        if !soql::is_safe_sobject_name(sobject) {
104            return Err(Error::new(ErrorKind::Salesforce {
105                error_code: "INVALID_SOBJECT".to_string(),
106                message: "Invalid SObject name".to_string(),
107            }));
108        }
109        if !url_security::is_valid_salesforce_id(id) {
110            return Err(Error::new(ErrorKind::Salesforce {
111                error_code: "INVALID_ID".to_string(),
112                message: "Invalid Salesforce ID format".to_string(),
113            }));
114        }
115        let path = format!("sobjects/{}/{}", sobject, id);
116        self.client.rest_delete(&path).await.map_err(Into::into)
117    }
118
119    /// Upsert a record using an external ID field.
120    ///
121    /// Creates the record if it doesn't exist, updates it if it does.
122    #[instrument(skip(self, record))]
123    pub async fn upsert<T: Serialize>(
124        &self,
125        sobject: &str,
126        external_id_field: &str,
127        external_id_value: &str,
128        record: &T,
129    ) -> Result<UpsertResult> {
130        if !soql::is_safe_sobject_name(sobject) {
131            return Err(Error::new(ErrorKind::Salesforce {
132                error_code: "INVALID_SOBJECT".to_string(),
133                message: "Invalid SObject name".to_string(),
134            }));
135        }
136        if !soql::is_safe_field_name(external_id_field) {
137            return Err(Error::new(ErrorKind::Salesforce {
138                error_code: "INVALID_FIELD".to_string(),
139                message: "Invalid external ID field name".to_string(),
140            }));
141        }
142        // URL-encode the external ID value to handle special characters
143        let encoded_value = url_security::encode_param(external_id_value);
144        let path = format!(
145            "sobjects/{}/{}/{}",
146            sobject, external_id_field, encoded_value
147        );
148        let url = self.client.rest_url(&path);
149        let request = self.client.patch(&url).json(record)?;
150        let response = self.client.execute(request).await?;
151
152        // Upsert returns 201 Created, 200 OK (updated), or 204 No Content (updated, older APIs)
153        let status = response.status();
154        if status == 201 || status == 200 {
155            // 201 Created or 200 Updated - response body has the result
156            let result: UpsertResult = response.json().await?;
157            Ok(result)
158        } else if status == 204 {
159            // Updated - no response body
160            Ok(UpsertResult {
161                id: external_id_value.to_string(),
162                success: true,
163                created: false,
164                errors: vec![],
165            })
166        } else {
167            Err(Error::new(ErrorKind::Salesforce {
168                error_code: "UPSERT_FAILED".to_string(),
169                message: format!("Unexpected status: {}", status),
170            }))
171        }
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::super::SalesforceRestClient;
    use wiremock::matchers::{method, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Start a mock server that answers every PATCH on the
    /// `Account` / `ExtId__c` upsert path with `response`.
    async fn upsert_mock_server(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("PATCH"))
            .and(path_regex(".*/sobjects/Account/ExtId__c/.*"))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    /// Client whose base URL is the mock server.
    fn client_for(server: &MockServer) -> SalesforceRestClient {
        SalesforceRestClient::new(server.uri(), "test-token").unwrap()
    }

    /// Client for tests that must fail validation before any request is made.
    fn offline_client() -> SalesforceRestClient {
        SalesforceRestClient::new("https://test.salesforce.com", "token").unwrap()
    }

    /// JSON body of a successful upsert response for the fixed test record ID.
    fn upsert_body(created: bool) -> serde_json::Value {
        serde_json::json!({
            "id": "001xx000003DgAAAS",
            "success": true,
            "created": created,
            "errors": []
        })
    }

    /// A 201 response is decoded and reported as a creation.
    #[tokio::test]
    async fn test_upsert_created_201_wiremock() {
        let template = ResponseTemplate::new(201).set_body_json(&upsert_body(true));
        let server = upsert_mock_server(template).await;

        let payload = serde_json::json!({"Name": "Test"});
        let outcome = client_for(&server)
            .upsert("Account", "ExtId__c", "ext-123", &payload)
            .await
            .expect("Upsert 201 should succeed");

        assert!(outcome.created);
        assert_eq!(outcome.id, "001xx000003DgAAAS");
    }

    /// A 200 response is decoded and reported as an update.
    #[tokio::test]
    async fn test_upsert_updated_200_wiremock() {
        let template = ResponseTemplate::new(200).set_body_json(&upsert_body(false));
        let server = upsert_mock_server(template).await;

        let payload = serde_json::json!({"Name": "Updated"});
        let outcome = client_for(&server)
            .upsert("Account", "ExtId__c", "ext-123", &payload)
            .await
            .expect("Upsert 200 should succeed");

        assert!(!outcome.created);
        assert_eq!(outcome.id, "001xx000003DgAAAS");
    }

    /// A body-less 204 response is reported as an update.
    #[tokio::test]
    async fn test_upsert_updated_204_wiremock() {
        let server = upsert_mock_server(ResponseTemplate::new(204)).await;

        let payload = serde_json::json!({"Name": "Updated"});
        let outcome = client_for(&server)
            .upsert("Account", "ExtId__c", "ext-123", &payload)
            .await
            .expect("Upsert 204 should succeed");

        assert!(!outcome.created);
    }

    /// An unsafe SObject name is rejected with `INVALID_SOBJECT`.
    #[tokio::test]
    async fn test_upsert_invalid_sobject() {
        let payload = serde_json::json!({});
        let outcome = offline_client()
            .upsert("Bad'; DROP--", "ExtId__c", "123", &payload)
            .await;

        assert!(outcome.is_err());
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("INVALID_SOBJECT"));
    }

    /// An unsafe external ID field name is rejected with `INVALID_FIELD`.
    #[tokio::test]
    async fn test_upsert_invalid_field() {
        let payload = serde_json::json!({});
        let outcome = offline_client()
            .upsert("Account", "Bad'; DROP--", "123", &payload)
            .await;

        assert!(outcome.is_err());
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("INVALID_FIELD"));
    }
}