1use serde::{de::DeserializeOwned, Serialize};
8use tracing::instrument;
9
10use busbar_sf_client::security::{soql, url as url_security};
11use busbar_sf_client::{ClientConfig, SalesforceClient};
12
13use crate::collections::{CollectionRequest, CollectionResult};
14use crate::composite::{CompositeRequest, CompositeResponse};
15use crate::describe::{DescribeGlobalResult, DescribeSObjectResult};
16use crate::error::{Error, ErrorKind, Result};
17use crate::query::QueryResult;
18use crate::sobject::{CreateResult, UpsertResult};
19
20#[derive(Debug, Clone)]
53pub struct SalesforceRestClient {
54 client: SalesforceClient,
55}
56
57impl SalesforceRestClient {
58 pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
60 let client = SalesforceClient::new(instance_url, access_token)?;
61 Ok(Self { client })
62 }
63
64 pub fn with_config(
66 instance_url: impl Into<String>,
67 access_token: impl Into<String>,
68 config: ClientConfig,
69 ) -> Result<Self> {
70 let client = SalesforceClient::with_config(instance_url, access_token, config)?;
71 Ok(Self { client })
72 }
73
74 pub fn from_client(client: SalesforceClient) -> Self {
76 Self { client }
77 }
78
79 pub fn inner(&self) -> &SalesforceClient {
81 &self.client
82 }
83
84 pub fn instance_url(&self) -> &str {
86 self.client.instance_url()
87 }
88
89 pub fn api_version(&self) -> &str {
91 self.client.api_version()
92 }
93
94 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
96 self.client = self.client.with_api_version(version);
97 self
98 }
99
100 #[instrument(skip(self))]
108 pub async fn describe_global(&self) -> Result<DescribeGlobalResult> {
109 self.client.rest_get("sobjects").await.map_err(Into::into)
110 }
111
112 #[instrument(skip(self))]
116 pub async fn describe_sobject(&self, sobject: &str) -> Result<DescribeSObjectResult> {
117 if !soql::is_safe_sobject_name(sobject) {
118 return Err(Error::new(ErrorKind::Salesforce {
119 error_code: "INVALID_SOBJECT".to_string(),
120 message: "Invalid SObject name".to_string(),
121 }));
122 }
123 let path = format!("sobjects/{}/describe", sobject);
124 self.client.rest_get(&path).await.map_err(Into::into)
125 }
126
127 #[instrument(skip(self, record))]
135 pub async fn create<T: Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
136 if !soql::is_safe_sobject_name(sobject) {
137 return Err(Error::new(ErrorKind::Salesforce {
138 error_code: "INVALID_SOBJECT".to_string(),
139 message: "Invalid SObject name".to_string(),
140 }));
141 }
142 let path = format!("sobjects/{}", sobject);
143 let result: CreateResult = self.client.rest_post(&path, record).await?;
144
145 if result.success {
146 Ok(result.id)
147 } else {
148 let errors: Vec<String> = result.errors.iter().map(|e| e.message.clone()).collect();
149 Err(Error::new(ErrorKind::Salesforce {
150 error_code: "CREATE_FAILED".to_string(),
151 message: errors.join("; "),
152 }))
153 }
154 }
155
156 #[instrument(skip(self))]
160 pub async fn get<T: DeserializeOwned>(
161 &self,
162 sobject: &str,
163 id: &str,
164 fields: Option<&[&str]>,
165 ) -> Result<T> {
166 if !soql::is_safe_sobject_name(sobject) {
167 return Err(Error::new(ErrorKind::Salesforce {
168 error_code: "INVALID_SOBJECT".to_string(),
169 message: "Invalid SObject name".to_string(),
170 }));
171 }
172 if !url_security::is_valid_salesforce_id(id) {
173 return Err(Error::new(ErrorKind::Salesforce {
174 error_code: "INVALID_ID".to_string(),
175 message: "Invalid Salesforce ID format".to_string(),
176 }));
177 }
178 let path = if let Some(fields) = fields {
179 let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
181 if safe_fields.is_empty() {
182 return Err(Error::new(ErrorKind::Salesforce {
183 error_code: "INVALID_FIELDS".to_string(),
184 message: "No valid field names provided".to_string(),
185 }));
186 }
187 format!(
188 "sobjects/{}/{}?fields={}",
189 sobject,
190 id,
191 safe_fields.join(",")
192 )
193 } else {
194 format!("sobjects/{}/{}", sobject, id)
195 };
196 self.client.rest_get(&path).await.map_err(Into::into)
197 }
198
199 #[instrument(skip(self, record))]
201 pub async fn update<T: Serialize>(&self, sobject: &str, id: &str, record: &T) -> Result<()> {
202 if !soql::is_safe_sobject_name(sobject) {
203 return Err(Error::new(ErrorKind::Salesforce {
204 error_code: "INVALID_SOBJECT".to_string(),
205 message: "Invalid SObject name".to_string(),
206 }));
207 }
208 if !url_security::is_valid_salesforce_id(id) {
209 return Err(Error::new(ErrorKind::Salesforce {
210 error_code: "INVALID_ID".to_string(),
211 message: "Invalid Salesforce ID format".to_string(),
212 }));
213 }
214 let path = format!("sobjects/{}/{}", sobject, id);
215 self.client
216 .rest_patch(&path, record)
217 .await
218 .map_err(Into::into)
219 }
220
221 #[instrument(skip(self))]
223 pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
224 if !soql::is_safe_sobject_name(sobject) {
225 return Err(Error::new(ErrorKind::Salesforce {
226 error_code: "INVALID_SOBJECT".to_string(),
227 message: "Invalid SObject name".to_string(),
228 }));
229 }
230 if !url_security::is_valid_salesforce_id(id) {
231 return Err(Error::new(ErrorKind::Salesforce {
232 error_code: "INVALID_ID".to_string(),
233 message: "Invalid Salesforce ID format".to_string(),
234 }));
235 }
236 let path = format!("sobjects/{}/{}", sobject, id);
237 self.client.rest_delete(&path).await.map_err(Into::into)
238 }
239
240 #[instrument(skip(self, record))]
244 pub async fn upsert<T: Serialize>(
245 &self,
246 sobject: &str,
247 external_id_field: &str,
248 external_id_value: &str,
249 record: &T,
250 ) -> Result<UpsertResult> {
251 if !soql::is_safe_sobject_name(sobject) {
252 return Err(Error::new(ErrorKind::Salesforce {
253 error_code: "INVALID_SOBJECT".to_string(),
254 message: "Invalid SObject name".to_string(),
255 }));
256 }
257 if !soql::is_safe_field_name(external_id_field) {
258 return Err(Error::new(ErrorKind::Salesforce {
259 error_code: "INVALID_FIELD".to_string(),
260 message: "Invalid external ID field name".to_string(),
261 }));
262 }
263 let encoded_value = url_security::encode_param(external_id_value);
265 let path = format!(
266 "sobjects/{}/{}/{}",
267 sobject, external_id_field, encoded_value
268 );
269 let url = self.client.rest_url(&path);
270 let request = self.client.patch(&url).json(record)?;
271 let response = self.client.execute(request).await?;
272
273 let status = response.status();
275 if status == 201 {
276 let result: UpsertResult = response.json().await?;
278 Ok(result)
279 } else if status == 204 {
280 Ok(UpsertResult {
282 id: external_id_value.to_string(),
283 success: true,
284 created: false,
285 errors: vec![],
286 })
287 } else {
288 Err(Error::new(ErrorKind::Salesforce {
289 error_code: "UPSERT_FAILED".to_string(),
290 message: format!("Unexpected status: {}", status),
291 }))
292 }
293 }
294
295 #[instrument(skip(self))]
319 pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
320 self.client.query(soql).await.map_err(Into::into)
321 }
322
323 #[instrument(skip(self))]
330 pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
331 self.client.query_all(soql).await.map_err(Into::into)
332 }
333
334 #[instrument(skip(self))]
341 pub async fn query_all_including_deleted<T: DeserializeOwned>(
342 &self,
343 soql: &str,
344 ) -> Result<QueryResult<T>> {
345 let encoded = urlencoding::encode(soql);
346 let url = format!(
347 "{}/services/data/v{}/queryAll?q={}",
348 self.client.instance_url(),
349 self.client.api_version(),
350 encoded
351 );
352 self.client.get_json(&url).await.map_err(Into::into)
353 }
354
355 #[instrument(skip(self))]
357 pub async fn query_more<T: DeserializeOwned>(
358 &self,
359 next_records_url: &str,
360 ) -> Result<QueryResult<T>> {
361 self.client
362 .get_json(next_records_url)
363 .await
364 .map_err(Into::into)
365 }
366
367 #[instrument(skip(self))]
379 pub async fn search<T: DeserializeOwned>(&self, sosl: &str) -> Result<SearchResult<T>> {
380 let encoded = urlencoding::encode(sosl);
381 let url = format!(
382 "{}/services/data/v{}/search?q={}",
383 self.client.instance_url(),
384 self.client.api_version(),
385 encoded
386 );
387 self.client.get_json(&url).await.map_err(Into::into)
388 }
389
390 #[instrument(skip(self, request))]
398 pub async fn composite(&self, request: &CompositeRequest) -> Result<CompositeResponse> {
399 self.client
400 .rest_post("composite", request)
401 .await
402 .map_err(Into::into)
403 }
404
405 #[instrument(skip(self, records))]
411 pub async fn create_multiple<T: Serialize>(
412 &self,
413 sobject: &str,
414 records: &[T],
415 all_or_none: bool,
416 ) -> Result<Vec<CollectionResult>> {
417 if !soql::is_safe_sobject_name(sobject) {
418 return Err(Error::new(ErrorKind::Salesforce {
419 error_code: "INVALID_SOBJECT".to_string(),
420 message: "Invalid SObject name".to_string(),
421 }));
422 }
423 let request = CollectionRequest {
424 all_or_none,
425 records: records
426 .iter()
427 .map(|r| {
428 let mut value = serde_json::to_value(r).unwrap_or(serde_json::Value::Null);
429 if let serde_json::Value::Object(ref mut map) = value {
430 map.insert(
431 "attributes".to_string(),
432 serde_json::json!({"type": sobject}),
433 );
434 }
435 value
436 })
437 .collect(),
438 };
439 self.client
440 .rest_post("composite/sobjects", &request)
441 .await
442 .map_err(Into::into)
443 }
444
445 #[instrument(skip(self, records))]
447 pub async fn update_multiple<T: Serialize>(
448 &self,
449 sobject: &str,
450 records: &[(String, T)], all_or_none: bool,
452 ) -> Result<Vec<CollectionResult>> {
453 if !soql::is_safe_sobject_name(sobject) {
454 return Err(Error::new(ErrorKind::Salesforce {
455 error_code: "INVALID_SOBJECT".to_string(),
456 message: "Invalid SObject name".to_string(),
457 }));
458 }
459 for (id, _) in records {
461 if !url_security::is_valid_salesforce_id(id) {
462 return Err(Error::new(ErrorKind::Salesforce {
463 error_code: "INVALID_ID".to_string(),
464 message: "Invalid Salesforce ID format".to_string(),
465 }));
466 }
467 }
468 let request = CollectionRequest {
469 all_or_none,
470 records: records
471 .iter()
472 .map(|(id, r)| {
473 let mut value = serde_json::to_value(r).unwrap_or(serde_json::Value::Null);
474 if let serde_json::Value::Object(ref mut map) = value {
475 map.insert(
476 "attributes".to_string(),
477 serde_json::json!({"type": sobject}),
478 );
479 map.insert("Id".to_string(), serde_json::json!(id));
480 }
481 value
482 })
483 .collect(),
484 };
485
486 let url = self.client.rest_url("composite/sobjects");
487 let request_builder = self.client.patch(&url).json(&request)?;
488 let response = self.client.execute(request_builder).await?;
489 response.json().await.map_err(Into::into)
490 }
491
492 #[instrument(skip(self))]
494 pub async fn delete_multiple(
495 &self,
496 ids: &[&str],
497 all_or_none: bool,
498 ) -> Result<Vec<CollectionResult>> {
499 for id in ids {
501 if !url_security::is_valid_salesforce_id(id) {
502 return Err(Error::new(ErrorKind::Salesforce {
503 error_code: "INVALID_ID".to_string(),
504 message: "Invalid Salesforce ID format".to_string(),
505 }));
506 }
507 }
508 let ids_param = ids.join(",");
509 let url = format!(
510 "{}/services/data/v{}/composite/sobjects?ids={}&allOrNone={}",
511 self.client.instance_url(),
512 self.client.api_version(),
513 ids_param,
514 all_or_none
515 );
516 let request = self.client.delete(&url);
517 let response = self.client.execute(request).await?;
518 response.json().await.map_err(Into::into)
519 }
520
521 #[instrument(skip(self))]
523 pub async fn get_multiple<T: DeserializeOwned>(
524 &self,
525 sobject: &str,
526 ids: &[&str],
527 fields: &[&str],
528 ) -> Result<Vec<T>> {
529 if !soql::is_safe_sobject_name(sobject) {
530 return Err(Error::new(ErrorKind::Salesforce {
531 error_code: "INVALID_SOBJECT".to_string(),
532 message: "Invalid SObject name".to_string(),
533 }));
534 }
535 for id in ids {
537 if !url_security::is_valid_salesforce_id(id) {
538 return Err(Error::new(ErrorKind::Salesforce {
539 error_code: "INVALID_ID".to_string(),
540 message: "Invalid Salesforce ID format".to_string(),
541 }));
542 }
543 }
544 let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
546 if safe_fields.is_empty() {
547 return Err(Error::new(ErrorKind::Salesforce {
548 error_code: "INVALID_FIELDS".to_string(),
549 message: "No valid field names provided".to_string(),
550 }));
551 }
552 let ids_param = ids.join(",");
553 let fields_param = safe_fields.join(",");
554 let url = format!(
555 "{}/services/data/v{}/composite/sobjects/{}?ids={}&fields={}",
556 self.client.instance_url(),
557 self.client.api_version(),
558 sobject,
559 ids_param,
560 fields_param
561 );
562 self.client.get_json(&url).await.map_err(Into::into)
563 }
564
565 #[instrument(skip(self))]
571 pub async fn limits(&self) -> Result<serde_json::Value> {
572 self.client.rest_get("limits").await.map_err(Into::into)
573 }
574
575 #[instrument(skip(self))]
581 pub async fn versions(&self) -> Result<Vec<ApiVersion>> {
582 let url = format!("{}/services/data", self.client.instance_url());
583 self.client.get_json(&url).await.map_err(Into::into)
584 }
585}
586
587#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
589pub struct SearchResult<T> {
590 #[serde(rename = "searchRecords")]
591 pub search_records: Vec<T>,
592}
593
594#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
596pub struct ApiVersion {
597 pub version: String,
598 pub label: String,
599 pub url: String,
600}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Instance URL shared by the construction tests.
    const INSTANCE: &str = "https://na1.salesforce.com";

    /// A freshly built client echoes its instance URL and reports the default API version.
    #[test]
    fn test_client_creation() {
        let client = SalesforceRestClient::new(INSTANCE, "token123")
            .expect("client construction should succeed");

        assert_eq!(client.instance_url(), INSTANCE);
        assert_eq!(client.api_version(), "62.0");
    }

    /// `with_api_version` replaces the default API version.
    #[test]
    fn test_api_version_override() {
        let client = SalesforceRestClient::new(INSTANCE, "token")
            .map(|built| built.with_api_version("60.0"))
            .expect("client construction should succeed");

        assert_eq!(client.api_version(), "60.0");
    }
}