Skip to main content

busbar_sf_rest/
client.rs

//! Salesforce REST API client.
//!
//! This client wraps `SalesforceClient` from the `busbar_sf_client` crate and
//! provides typed methods for REST API operations including CRUD, Query,
//! Search, Describe, Composite, and SObject Collections.
6
7use serde::{de::DeserializeOwned, Serialize};
8use tracing::instrument;
9
10use busbar_sf_client::security::{soql, url as url_security};
11use busbar_sf_client::{ClientConfig, SalesforceClient};
12
13use crate::collections::{CollectionRequest, CollectionResult};
14use crate::composite::{CompositeRequest, CompositeResponse};
15use crate::describe::{DescribeGlobalResult, DescribeSObjectResult};
16use crate::error::{Error, ErrorKind, Result};
17use crate::query::QueryResult;
18use crate::sobject::{CreateResult, UpsertResult};
19
/// Salesforce REST API client.
///
/// A wrapper around [`SalesforceClient`] that provides typed methods for the
/// REST API operations implemented in this file:
/// - CRUD operations on SObjects (including upsert by external ID)
/// - SOQL queries with automatic pagination
/// - SOSL search
/// - Describe operations
/// - Composite API
/// - SObject Collections
/// - Org limits and available API versions
///
/// # Example
///
/// ```rust,ignore
/// use busbar_sf_rest::SalesforceRestClient;
///
/// let client = SalesforceRestClient::new(
///     "https://myorg.my.salesforce.com",
///     "access_token_here",
/// )?;
///
/// // Query
/// let accounts: Vec<Account> = client.query_all("SELECT Id, Name FROM Account").await?;
///
/// // Create
/// let id = client.create("Account", &json!({"Name": "New Account"})).await?;
///
/// // Update
/// client.update("Account", &id, &json!({"Name": "Updated"})).await?;
///
/// // Delete
/// client.delete("Account", &id).await?;
/// ```
#[derive(Debug, Clone)]
pub struct SalesforceRestClient {
    /// Wrapped client; every request made by this type is issued through it.
    client: SalesforceClient,
}
56
57impl SalesforceRestClient {
58    /// Create a new REST client with the given instance URL and access token.
59    pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
60        let client = SalesforceClient::new(instance_url, access_token)?;
61        Ok(Self { client })
62    }
63
64    /// Create a new REST client with custom HTTP configuration.
65    pub fn with_config(
66        instance_url: impl Into<String>,
67        access_token: impl Into<String>,
68        config: ClientConfig,
69    ) -> Result<Self> {
70        let client = SalesforceClient::with_config(instance_url, access_token, config)?;
71        Ok(Self { client })
72    }
73
74    /// Create a REST client from an existing SalesforceClient.
75    pub fn from_client(client: SalesforceClient) -> Self {
76        Self { client }
77    }
78
79    /// Get the underlying SalesforceClient.
80    pub fn inner(&self) -> &SalesforceClient {
81        &self.client
82    }
83
84    /// Get the instance URL.
85    pub fn instance_url(&self) -> &str {
86        self.client.instance_url()
87    }
88
89    /// Get the API version.
90    pub fn api_version(&self) -> &str {
91        self.client.api_version()
92    }
93
94    /// Set the API version.
95    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
96        self.client = self.client.with_api_version(version);
97        self
98    }
99
100    // =========================================================================
101    // Describe Operations
102    // =========================================================================
103
104    /// Get a list of all SObjects available in the org.
105    ///
106    /// This is equivalent to calling `/services/data/vXX.0/sobjects/`.
107    #[instrument(skip(self))]
108    pub async fn describe_global(&self) -> Result<DescribeGlobalResult> {
109        self.client.rest_get("sobjects").await.map_err(Into::into)
110    }
111
112    /// Get detailed metadata for a specific SObject.
113    ///
114    /// This is equivalent to calling `/services/data/vXX.0/sobjects/{sobject}/describe`.
115    #[instrument(skip(self))]
116    pub async fn describe_sobject(&self, sobject: &str) -> Result<DescribeSObjectResult> {
117        if !soql::is_safe_sobject_name(sobject) {
118            return Err(Error::new(ErrorKind::Salesforce {
119                error_code: "INVALID_SOBJECT".to_string(),
120                message: "Invalid SObject name".to_string(),
121            }));
122        }
123        let path = format!("sobjects/{}/describe", sobject);
124        self.client.rest_get(&path).await.map_err(Into::into)
125    }
126
127    // =========================================================================
128    // CRUD Operations
129    // =========================================================================
130
131    /// Create a new record.
132    ///
133    /// Returns the ID of the created record.
134    #[instrument(skip(self, record))]
135    pub async fn create<T: Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
136        if !soql::is_safe_sobject_name(sobject) {
137            return Err(Error::new(ErrorKind::Salesforce {
138                error_code: "INVALID_SOBJECT".to_string(),
139                message: "Invalid SObject name".to_string(),
140            }));
141        }
142        let path = format!("sobjects/{}", sobject);
143        let result: CreateResult = self.client.rest_post(&path, record).await?;
144
145        if result.success {
146            Ok(result.id)
147        } else {
148            let errors: Vec<String> = result.errors.iter().map(|e| e.message.clone()).collect();
149            Err(Error::new(ErrorKind::Salesforce {
150                error_code: "CREATE_FAILED".to_string(),
151                message: errors.join("; "),
152            }))
153        }
154    }
155
156    /// Get a record by ID.
157    ///
158    /// Optionally specify which fields to retrieve.
159    #[instrument(skip(self))]
160    pub async fn get<T: DeserializeOwned>(
161        &self,
162        sobject: &str,
163        id: &str,
164        fields: Option<&[&str]>,
165    ) -> Result<T> {
166        if !soql::is_safe_sobject_name(sobject) {
167            return Err(Error::new(ErrorKind::Salesforce {
168                error_code: "INVALID_SOBJECT".to_string(),
169                message: "Invalid SObject name".to_string(),
170            }));
171        }
172        if !url_security::is_valid_salesforce_id(id) {
173            return Err(Error::new(ErrorKind::Salesforce {
174                error_code: "INVALID_ID".to_string(),
175                message: "Invalid Salesforce ID format".to_string(),
176            }));
177        }
178        let path = if let Some(fields) = fields {
179            // Validate and filter field names for safety
180            let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
181            if safe_fields.is_empty() {
182                return Err(Error::new(ErrorKind::Salesforce {
183                    error_code: "INVALID_FIELDS".to_string(),
184                    message: "No valid field names provided".to_string(),
185                }));
186            }
187            format!(
188                "sobjects/{}/{}?fields={}",
189                sobject,
190                id,
191                safe_fields.join(",")
192            )
193        } else {
194            format!("sobjects/{}/{}", sobject, id)
195        };
196        self.client.rest_get(&path).await.map_err(Into::into)
197    }
198
199    /// Update a record.
200    #[instrument(skip(self, record))]
201    pub async fn update<T: Serialize>(&self, sobject: &str, id: &str, record: &T) -> Result<()> {
202        if !soql::is_safe_sobject_name(sobject) {
203            return Err(Error::new(ErrorKind::Salesforce {
204                error_code: "INVALID_SOBJECT".to_string(),
205                message: "Invalid SObject name".to_string(),
206            }));
207        }
208        if !url_security::is_valid_salesforce_id(id) {
209            return Err(Error::new(ErrorKind::Salesforce {
210                error_code: "INVALID_ID".to_string(),
211                message: "Invalid Salesforce ID format".to_string(),
212            }));
213        }
214        let path = format!("sobjects/{}/{}", sobject, id);
215        self.client
216            .rest_patch(&path, record)
217            .await
218            .map_err(Into::into)
219    }
220
221    /// Delete a record.
222    #[instrument(skip(self))]
223    pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
224        if !soql::is_safe_sobject_name(sobject) {
225            return Err(Error::new(ErrorKind::Salesforce {
226                error_code: "INVALID_SOBJECT".to_string(),
227                message: "Invalid SObject name".to_string(),
228            }));
229        }
230        if !url_security::is_valid_salesforce_id(id) {
231            return Err(Error::new(ErrorKind::Salesforce {
232                error_code: "INVALID_ID".to_string(),
233                message: "Invalid Salesforce ID format".to_string(),
234            }));
235        }
236        let path = format!("sobjects/{}/{}", sobject, id);
237        self.client.rest_delete(&path).await.map_err(Into::into)
238    }
239
240    /// Upsert a record using an external ID field.
241    ///
242    /// Creates the record if it doesn't exist, updates it if it does.
243    #[instrument(skip(self, record))]
244    pub async fn upsert<T: Serialize>(
245        &self,
246        sobject: &str,
247        external_id_field: &str,
248        external_id_value: &str,
249        record: &T,
250    ) -> Result<UpsertResult> {
251        if !soql::is_safe_sobject_name(sobject) {
252            return Err(Error::new(ErrorKind::Salesforce {
253                error_code: "INVALID_SOBJECT".to_string(),
254                message: "Invalid SObject name".to_string(),
255            }));
256        }
257        if !soql::is_safe_field_name(external_id_field) {
258            return Err(Error::new(ErrorKind::Salesforce {
259                error_code: "INVALID_FIELD".to_string(),
260                message: "Invalid external ID field name".to_string(),
261            }));
262        }
263        // URL-encode the external ID value to handle special characters
264        let encoded_value = url_security::encode_param(external_id_value);
265        let path = format!(
266            "sobjects/{}/{}/{}",
267            sobject, external_id_field, encoded_value
268        );
269        let url = self.client.rest_url(&path);
270        let request = self.client.patch(&url).json(record)?;
271        let response = self.client.execute(request).await?;
272
273        // Upsert returns 201 Created or 204 No Content
274        let status = response.status();
275        if status == 201 {
276            // Created - response has the ID
277            let result: UpsertResult = response.json().await?;
278            Ok(result)
279        } else if status == 204 {
280            // Updated - no response body
281            Ok(UpsertResult {
282                id: external_id_value.to_string(),
283                success: true,
284                created: false,
285                errors: vec![],
286            })
287        } else {
288            Err(Error::new(ErrorKind::Salesforce {
289                error_code: "UPSERT_FAILED".to_string(),
290                message: format!("Unexpected status: {}", status),
291            }))
292        }
293    }
294
295    // =========================================================================
296    // Query Operations
297    // =========================================================================
298
299    /// Execute a SOQL query.
300    ///
301    /// Returns the first page of results. Use `query_all` for automatic pagination.
302    ///
303    /// # Security
304    ///
305    /// **IMPORTANT**: If you are including user-provided values in the WHERE clause,
306    /// you MUST escape them to prevent SOQL injection attacks. Use the security utilities:
307    ///
308    /// ```rust,ignore
309    /// use busbar_sf_client::security::soql;
310    ///
311    /// // WRONG - vulnerable to injection:
312    /// let query = format!("SELECT Id FROM Account WHERE Name = '{}'", user_input);
313    ///
314    /// // CORRECT - properly escaped:
315    /// let safe_value = soql::escape_string(user_input);
316    /// let query = format!("SELECT Id FROM Account WHERE Name = '{}'", safe_value);
317    /// ```
318    #[instrument(skip(self))]
319    pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
320        self.client.query(soql).await.map_err(Into::into)
321    }
322
323    /// Execute a SOQL query and return all results (automatic pagination).
324    ///
325    /// # Security
326    ///
327    /// **IMPORTANT**: Escape user-provided values with `busbar_sf_client::security::soql::escape_string()`
328    /// to prevent SOQL injection attacks. See `query()` for examples.
329    #[instrument(skip(self))]
330    pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
331        self.client.query_all(soql).await.map_err(Into::into)
332    }
333
334    /// Execute a SOQL query including deleted/archived records.
335    ///
336    /// # Security
337    ///
338    /// **IMPORTANT**: Escape user-provided values with `busbar_sf_client::security::soql::escape_string()`
339    /// to prevent SOQL injection attacks. See `query()` for examples.
340    #[instrument(skip(self))]
341    pub async fn query_all_including_deleted<T: DeserializeOwned>(
342        &self,
343        soql: &str,
344    ) -> Result<QueryResult<T>> {
345        let encoded = urlencoding::encode(soql);
346        let url = format!(
347            "{}/services/data/v{}/queryAll?q={}",
348            self.client.instance_url(),
349            self.client.api_version(),
350            encoded
351        );
352        self.client.get_json(&url).await.map_err(Into::into)
353    }
354
355    /// Fetch the next page of query results.
356    #[instrument(skip(self))]
357    pub async fn query_more<T: DeserializeOwned>(
358        &self,
359        next_records_url: &str,
360    ) -> Result<QueryResult<T>> {
361        self.client
362            .get_json(next_records_url)
363            .await
364            .map_err(Into::into)
365    }
366
367    // =========================================================================
368    // Search Operations (SOSL)
369    // =========================================================================
370
371    /// Execute a SOSL search.
372    ///
373    /// # Security
374    ///
375    /// **IMPORTANT**: If you are including user-provided values in the search term,
376    /// you MUST escape them. Use `busbar_sf_client::security::soql::escape_string()`
377    /// for string values in SOSL queries.
378    #[instrument(skip(self))]
379    pub async fn search<T: DeserializeOwned>(&self, sosl: &str) -> Result<SearchResult<T>> {
380        let encoded = urlencoding::encode(sosl);
381        let url = format!(
382            "{}/services/data/v{}/search?q={}",
383            self.client.instance_url(),
384            self.client.api_version(),
385            encoded
386        );
387        self.client.get_json(&url).await.map_err(Into::into)
388    }
389
390    // =========================================================================
391    // Composite API
392    // =========================================================================
393
394    /// Execute a composite request with multiple subrequests.
395    ///
396    /// The composite API allows up to 25 subrequests in a single API call.
397    #[instrument(skip(self, request))]
398    pub async fn composite(&self, request: &CompositeRequest) -> Result<CompositeResponse> {
399        self.client
400            .rest_post("composite", request)
401            .await
402            .map_err(Into::into)
403    }
404
405    // =========================================================================
406    // SObject Collections
407    // =========================================================================
408
409    /// Create multiple records in a single request (up to 200).
410    #[instrument(skip(self, records))]
411    pub async fn create_multiple<T: Serialize>(
412        &self,
413        sobject: &str,
414        records: &[T],
415        all_or_none: bool,
416    ) -> Result<Vec<CollectionResult>> {
417        if !soql::is_safe_sobject_name(sobject) {
418            return Err(Error::new(ErrorKind::Salesforce {
419                error_code: "INVALID_SOBJECT".to_string(),
420                message: "Invalid SObject name".to_string(),
421            }));
422        }
423        let request = CollectionRequest {
424            all_or_none,
425            records: records
426                .iter()
427                .map(|r| {
428                    let mut value = serde_json::to_value(r).unwrap_or(serde_json::Value::Null);
429                    if let serde_json::Value::Object(ref mut map) = value {
430                        map.insert(
431                            "attributes".to_string(),
432                            serde_json::json!({"type": sobject}),
433                        );
434                    }
435                    value
436                })
437                .collect(),
438        };
439        self.client
440            .rest_post("composite/sobjects", &request)
441            .await
442            .map_err(Into::into)
443    }
444
445    /// Update multiple records in a single request (up to 200).
446    #[instrument(skip(self, records))]
447    pub async fn update_multiple<T: Serialize>(
448        &self,
449        sobject: &str,
450        records: &[(String, T)], // (id, record)
451        all_or_none: bool,
452    ) -> Result<Vec<CollectionResult>> {
453        if !soql::is_safe_sobject_name(sobject) {
454            return Err(Error::new(ErrorKind::Salesforce {
455                error_code: "INVALID_SOBJECT".to_string(),
456                message: "Invalid SObject name".to_string(),
457            }));
458        }
459        // Validate all IDs
460        for (id, _) in records {
461            if !url_security::is_valid_salesforce_id(id) {
462                return Err(Error::new(ErrorKind::Salesforce {
463                    error_code: "INVALID_ID".to_string(),
464                    message: "Invalid Salesforce ID format".to_string(),
465                }));
466            }
467        }
468        let request = CollectionRequest {
469            all_or_none,
470            records: records
471                .iter()
472                .map(|(id, r)| {
473                    let mut value = serde_json::to_value(r).unwrap_or(serde_json::Value::Null);
474                    if let serde_json::Value::Object(ref mut map) = value {
475                        map.insert(
476                            "attributes".to_string(),
477                            serde_json::json!({"type": sobject}),
478                        );
479                        map.insert("Id".to_string(), serde_json::json!(id));
480                    }
481                    value
482                })
483                .collect(),
484        };
485
486        let url = self.client.rest_url("composite/sobjects");
487        let request_builder = self.client.patch(&url).json(&request)?;
488        let response = self.client.execute(request_builder).await?;
489        response.json().await.map_err(Into::into)
490    }
491
492    /// Delete multiple records in a single request (up to 200).
493    #[instrument(skip(self))]
494    pub async fn delete_multiple(
495        &self,
496        ids: &[&str],
497        all_or_none: bool,
498    ) -> Result<Vec<CollectionResult>> {
499        // Validate all IDs before proceeding
500        for id in ids {
501            if !url_security::is_valid_salesforce_id(id) {
502                return Err(Error::new(ErrorKind::Salesforce {
503                    error_code: "INVALID_ID".to_string(),
504                    message: "Invalid Salesforce ID format".to_string(),
505                }));
506            }
507        }
508        let ids_param = ids.join(",");
509        let url = format!(
510            "{}/services/data/v{}/composite/sobjects?ids={}&allOrNone={}",
511            self.client.instance_url(),
512            self.client.api_version(),
513            ids_param,
514            all_or_none
515        );
516        let request = self.client.delete(&url);
517        let response = self.client.execute(request).await?;
518        response.json().await.map_err(Into::into)
519    }
520
521    /// Get multiple records by ID in a single request (up to 2000).
522    #[instrument(skip(self))]
523    pub async fn get_multiple<T: DeserializeOwned>(
524        &self,
525        sobject: &str,
526        ids: &[&str],
527        fields: &[&str],
528    ) -> Result<Vec<T>> {
529        if !soql::is_safe_sobject_name(sobject) {
530            return Err(Error::new(ErrorKind::Salesforce {
531                error_code: "INVALID_SOBJECT".to_string(),
532                message: "Invalid SObject name".to_string(),
533            }));
534        }
535        // Validate all IDs
536        for id in ids {
537            if !url_security::is_valid_salesforce_id(id) {
538                return Err(Error::new(ErrorKind::Salesforce {
539                    error_code: "INVALID_ID".to_string(),
540                    message: "Invalid Salesforce ID format".to_string(),
541                }));
542            }
543        }
544        // Validate and filter field names
545        let safe_fields: Vec<&str> = soql::filter_safe_fields(fields.iter().copied()).collect();
546        if safe_fields.is_empty() {
547            return Err(Error::new(ErrorKind::Salesforce {
548                error_code: "INVALID_FIELDS".to_string(),
549                message: "No valid field names provided".to_string(),
550            }));
551        }
552        let ids_param = ids.join(",");
553        let fields_param = safe_fields.join(",");
554        let url = format!(
555            "{}/services/data/v{}/composite/sobjects/{}?ids={}&fields={}",
556            self.client.instance_url(),
557            self.client.api_version(),
558            sobject,
559            ids_param,
560            fields_param
561        );
562        self.client.get_json(&url).await.map_err(Into::into)
563    }
564
565    // =========================================================================
566    // Limits
567    // =========================================================================
568
569    /// Get API limits for the org.
570    #[instrument(skip(self))]
571    pub async fn limits(&self) -> Result<serde_json::Value> {
572        self.client.rest_get("limits").await.map_err(Into::into)
573    }
574
575    // =========================================================================
576    // API Versions
577    // =========================================================================
578
579    /// Get available API versions.
580    #[instrument(skip(self))]
581    pub async fn versions(&self) -> Result<Vec<ApiVersion>> {
582        let url = format!("{}/services/data", self.client.instance_url());
583        self.client.get_json(&url).await.map_err(Into::into)
584    }
585}
586
/// Result of a SOSL search.
///
/// This is the type that `SalesforceRestClient::search` deserializes its
/// response into; `T` is the caller-chosen record type.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct SearchResult<T> {
    /// Records returned by the search, mapped to the `searchRecords` JSON key.
    #[serde(rename = "searchRecords")]
    pub search_records: Vec<T>,
}
593
/// API version information.
///
/// One element of the list that `SalesforceRestClient::versions` deserializes
/// from the `/services/data` response.
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct ApiVersion {
    /// Version identifier; presumably of the form `"62.0"` (TODO confirm
    /// against a live response).
    pub version: String,
    /// Label for the version; presumably a human-readable release name (TODO confirm).
    pub label: String,
    /// URL associated with the version; presumably the versioned REST root
    /// path (TODO confirm).
    pub url: String,
}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Instance URL shared by the construction tests.
    const INSTANCE: &str = "https://na1.salesforce.com";

    /// A freshly built client reports the instance URL it was given and an
    /// API version of `62.0`.
    #[test]
    fn test_client_creation() {
        let rest = SalesforceRestClient::new(INSTANCE, "token123").unwrap();

        assert_eq!(rest.instance_url(), INSTANCE);
        assert_eq!(rest.api_version(), "62.0");
    }

    /// `with_api_version` replaces the version reported by the client.
    #[test]
    fn test_api_version_override() {
        let built = SalesforceRestClient::new(INSTANCE, "token").unwrap();
        let pinned = built.with_api_version("60.0");

        assert_eq!(pinned.api_version(), "60.0");
    }
}