1use serde::{de::DeserializeOwned, Serialize};
2use tracing::instrument;
3
4use busbar_sf_client::security::{soql, url as url_security};
5
6use crate::error::{Error, ErrorKind, Result};
7use crate::sobject::{CreateResult, UpsertResult};
8
9impl super::SalesforceRestClient {
10 #[instrument(skip(self, record))]
14 pub async fn create<T: Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
15 if !soql::is_safe_sobject_name(sobject) {
16 return Err(Error::new(ErrorKind::Salesforce {
17 error_code: "INVALID_SOBJECT".to_string(),
18 message: "Invalid SObject name".to_string(),
19 }));
20 }
21 let path = format!("sobjects/{}", sobject);
22 let result: CreateResult = self.client.rest_post(&path, record).await?;
23
24 if result.success {
25 Ok(result.id)
26 } else {
27 let errors: Vec<String> = result.errors.iter().map(|e| e.message.clone()).collect();
28 Err(Error::new(ErrorKind::Salesforce {
29 error_code: "CREATE_FAILED".to_string(),
30 message: errors.join("; "),
31 }))
32 }
33 }
34
35 #[instrument(skip(self))]
39 pub async fn get<T: DeserializeOwned>(
40 &self,
41 sobject: &str,
42 id: &str,
43 fields: Option<&[&str]>,
44 ) -> Result<T> {
45 if !soql::is_safe_sobject_name(sobject) {
46 return Err(Error::new(ErrorKind::Salesforce {
47 error_code: "INVALID_SOBJECT".to_string(),
48 message: "Invalid SObject name".to_string(),
49 }));
50 }
51 if !url_security::is_valid_salesforce_id(id) {
52 return Err(Error::new(ErrorKind::Salesforce {
53 error_code: "INVALID_ID".to_string(),
54 message: "Invalid Salesforce ID format".to_string(),
55 }));
56 }
57 let path = if let Some(fields) = fields {
58 let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
60 if safe_fields.is_empty() {
61 return Err(Error::new(ErrorKind::Salesforce {
62 error_code: "INVALID_FIELDS".to_string(),
63 message: "No valid field names provided".to_string(),
64 }));
65 }
66 format!(
67 "sobjects/{}/{}?fields={}",
68 sobject,
69 id,
70 safe_fields.join(",")
71 )
72 } else {
73 format!("sobjects/{}/{}", sobject, id)
74 };
75 self.client.rest_get(&path).await.map_err(Into::into)
76 }
77
78 #[instrument(skip(self, record))]
80 pub async fn update<T: Serialize>(&self, sobject: &str, id: &str, record: &T) -> Result<()> {
81 if !soql::is_safe_sobject_name(sobject) {
82 return Err(Error::new(ErrorKind::Salesforce {
83 error_code: "INVALID_SOBJECT".to_string(),
84 message: "Invalid SObject name".to_string(),
85 }));
86 }
87 if !url_security::is_valid_salesforce_id(id) {
88 return Err(Error::new(ErrorKind::Salesforce {
89 error_code: "INVALID_ID".to_string(),
90 message: "Invalid Salesforce ID format".to_string(),
91 }));
92 }
93 let path = format!("sobjects/{}/{}", sobject, id);
94 self.client
95 .rest_patch(&path, record)
96 .await
97 .map_err(Into::into)
98 }
99
100 #[instrument(skip(self))]
102 pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
103 if !soql::is_safe_sobject_name(sobject) {
104 return Err(Error::new(ErrorKind::Salesforce {
105 error_code: "INVALID_SOBJECT".to_string(),
106 message: "Invalid SObject name".to_string(),
107 }));
108 }
109 if !url_security::is_valid_salesforce_id(id) {
110 return Err(Error::new(ErrorKind::Salesforce {
111 error_code: "INVALID_ID".to_string(),
112 message: "Invalid Salesforce ID format".to_string(),
113 }));
114 }
115 let path = format!("sobjects/{}/{}", sobject, id);
116 self.client.rest_delete(&path).await.map_err(Into::into)
117 }
118
119 #[instrument(skip(self, record))]
123 pub async fn upsert<T: Serialize>(
124 &self,
125 sobject: &str,
126 external_id_field: &str,
127 external_id_value: &str,
128 record: &T,
129 ) -> Result<UpsertResult> {
130 if !soql::is_safe_sobject_name(sobject) {
131 return Err(Error::new(ErrorKind::Salesforce {
132 error_code: "INVALID_SOBJECT".to_string(),
133 message: "Invalid SObject name".to_string(),
134 }));
135 }
136 if !soql::is_safe_field_name(external_id_field) {
137 return Err(Error::new(ErrorKind::Salesforce {
138 error_code: "INVALID_FIELD".to_string(),
139 message: "Invalid external ID field name".to_string(),
140 }));
141 }
142 let encoded_value = url_security::encode_param(external_id_value);
144 let path = format!(
145 "sobjects/{}/{}/{}",
146 sobject, external_id_field, encoded_value
147 );
148 let url = self.client.rest_url(&path);
149 let request = self.client.patch(&url).json(record)?;
150 let response = self.client.execute(request).await?;
151
152 let status = response.status();
154 if status == 201 || status == 200 {
155 let result: UpsertResult = response.json().await?;
157 Ok(result)
158 } else if status == 204 {
159 Ok(UpsertResult {
161 id: external_id_value.to_string(),
162 success: true,
163 created: false,
164 errors: vec![],
165 })
166 } else {
167 Err(Error::new(ErrorKind::Salesforce {
168 error_code: "UPSERT_FAILED".to_string(),
169 message: format!("Unexpected status: {}", status),
170 }))
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::super::SalesforceRestClient;
    use wiremock::matchers::{method, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path pattern for every mocked upsert: `Account` keyed by `ExtId__c`.
    const UPSERT_PATH: &str = ".*/sobjects/Account/ExtId__c/.*";

    /// A syntactically plausible record ID for tests where another argument is invalid.
    const SAMPLE_ID: &str = "001D000000IqhSLIAZ";

    /// Starts a mock server that answers every matching upsert PATCH with `response`.
    ///
    /// The returned server must be kept alive for the duration of the test.
    async fn mock_upsert(response: ResponseTemplate) -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("PATCH"))
            .and(path_regex(UPSERT_PATH))
            .respond_with(response)
            .mount(&server)
            .await;
        server
    }

    /// Client for tests whose inputs are expected to be rejected before any request is sent.
    fn offline_client() -> SalesforceRestClient {
        SalesforceRestClient::new("https://test.salesforce.com", "token").unwrap()
    }

    #[tokio::test]
    async fn test_upsert_created_201_wiremock() {
        let body = serde_json::json!({
            "id": "001xx000003DgAAAS",
            "success": true,
            "created": true,
            "errors": []
        });
        let mock_server = mock_upsert(ResponseTemplate::new(201).set_body_json(&body)).await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .upsert(
                "Account",
                "ExtId__c",
                "ext-123",
                &serde_json::json!({"Name": "Test"}),
            )
            .await
            .expect("Upsert 201 should succeed");
        assert!(result.created);
        assert_eq!(result.id, "001xx000003DgAAAS");
    }

    #[tokio::test]
    async fn test_upsert_updated_200_wiremock() {
        let body = serde_json::json!({
            "id": "001xx000003DgAAAS",
            "success": true,
            "created": false,
            "errors": []
        });
        let mock_server = mock_upsert(ResponseTemplate::new(200).set_body_json(&body)).await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .upsert(
                "Account",
                "ExtId__c",
                "ext-123",
                &serde_json::json!({"Name": "Updated"}),
            )
            .await
            .expect("Upsert 200 should succeed");
        assert!(!result.created);
        assert_eq!(result.id, "001xx000003DgAAAS");
    }

    #[tokio::test]
    async fn test_upsert_updated_204_wiremock() {
        let mock_server = mock_upsert(ResponseTemplate::new(204)).await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .upsert(
                "Account",
                "ExtId__c",
                "ext-123",
                &serde_json::json!({"Name": "Updated"}),
            )
            .await
            .expect("Upsert 204 should succeed");
        assert!(!result.created);
        assert!(result.success);
        assert!(result.errors.is_empty());
        // With no response body, the external ID value is echoed back as `id`.
        assert_eq!(result.id, "ext-123");
    }

    #[tokio::test]
    async fn test_upsert_error_status_wiremock() {
        let body = serde_json::json!([
            {"message": "No such column", "errorCode": "INVALID_FIELD"}
        ]);
        let mock_server = mock_upsert(ResponseTemplate::new(400).set_body_json(&body)).await;

        let client = SalesforceRestClient::new(mock_server.uri(), "test-token").unwrap();
        let result = client
            .upsert(
                "Account",
                "ExtId__c",
                "ext-123",
                &serde_json::json!({"Name": "Bad"}),
            )
            .await;
        assert!(result.is_err(), "a 400 response must not be reported as success");
    }

    #[tokio::test]
    async fn test_upsert_invalid_sobject() {
        let result = offline_client()
            .upsert("Bad'; DROP--", "ExtId__c", "123", &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_upsert_invalid_field() {
        let result = offline_client()
            .upsert("Account", "Bad'; DROP--", "123", &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("INVALID_FIELD"));
    }

    #[tokio::test]
    async fn test_create_invalid_sobject() {
        let result = offline_client()
            .create("Bad'; DROP--", &serde_json::json!({}))
            .await;
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_invalid_sobject() {
        let result = offline_client()
            .get::<serde_json::Value>("Bad'; DROP--", SAMPLE_ID, None)
            .await;
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_get_invalid_id() {
        let result = offline_client()
            .get::<serde_json::Value>("Account", "not-an-id", None)
            .await;
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_update_invalid_id() {
        let result = offline_client()
            .update("Account", "not-an-id", &serde_json::json!({"Name": "x"}))
            .await;
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }

    #[tokio::test]
    async fn test_delete_invalid_sobject() {
        let result = offline_client().delete("Bad'; DROP--", SAMPLE_ID).await;
        assert!(result.unwrap_err().to_string().contains("INVALID_SOBJECT"));
    }

    #[tokio::test]
    async fn test_delete_invalid_id() {
        let result = offline_client().delete("Account", "not-an-id").await;
        assert!(result.unwrap_err().to_string().contains("INVALID_ID"));
    }
}