Skip to main content

busbar_sf_tooling/
client.rs

//! Salesforce Tooling API client.
//!
//! This client wraps `SalesforceClient` from the `busbar_sf_client` crate and
//! provides typed methods for Tooling API operations.
5
6use serde::de::DeserializeOwned;
7use tracing::instrument;
8
9use busbar_sf_client::security::{soql, url as url_security};
10use busbar_sf_client::{ClientConfig, QueryResult, SalesforceClient};
11
12use crate::error::{Error, ErrorKind, Result};
13use crate::types::*;
14
/// Salesforce Tooling API client.
///
/// A thin wrapper around [`SalesforceClient`] that provides typed methods for
/// Tooling API operations:
/// - Execute anonymous Apex
/// - Query Apex classes, triggers, and logs
/// - Manage debug logs and trace flags
/// - Code coverage information
///
/// # Example
///
/// ```rust,ignore
/// use busbar_sf_tooling::ToolingClient;
///
/// let client = ToolingClient::new(
///     "https://myorg.my.salesforce.com",
///     "access_token_here",
/// )?;
///
/// // Execute anonymous Apex
/// let result = client.execute_anonymous("System.debug('Hello');").await?;
///
/// // Query Apex classes
/// let classes: Vec<ApexClass> = client
///     .query_all("SELECT Id, Name FROM ApexClass")
///     .await?;
/// ```
#[derive(Debug, Clone)]
pub struct ToolingClient {
    /// Underlying client; every Tooling API request in this type goes through it.
    client: SalesforceClient,
}
45
46impl ToolingClient {
47    /// Create a new Tooling API client with the given instance URL and access token.
48    pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
49        let client = SalesforceClient::new(instance_url, access_token)?;
50        Ok(Self { client })
51    }
52
53    /// Create a new Tooling API client with custom HTTP configuration.
54    pub fn with_config(
55        instance_url: impl Into<String>,
56        access_token: impl Into<String>,
57        config: ClientConfig,
58    ) -> Result<Self> {
59        let client = SalesforceClient::with_config(instance_url, access_token, config)?;
60        Ok(Self { client })
61    }
62
63    /// Create a Tooling client from an existing SalesforceClient.
64    pub fn from_client(client: SalesforceClient) -> Self {
65        Self { client }
66    }
67
68    /// Get the underlying SalesforceClient.
69    pub fn inner(&self) -> &SalesforceClient {
70        &self.client
71    }
72
73    /// Get the instance URL.
74    pub fn instance_url(&self) -> &str {
75        self.client.instance_url()
76    }
77
78    /// Get the API version.
79    pub fn api_version(&self) -> &str {
80        self.client.api_version()
81    }
82
83    /// Set the API version.
84    pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
85        self.client = self.client.with_api_version(version);
86        self
87    }
88
89    // =========================================================================
90    // Query Operations
91    // =========================================================================
92
93    /// Execute a SOQL query against the Tooling API.
94    ///
95    /// Returns the first page of results. Use `query_all` for automatic pagination.
96    ///
97    /// # Security
98    ///
99    /// **IMPORTANT**: If you are including user-provided values in the WHERE clause,
100    /// you MUST escape them to prevent SOQL injection attacks:
101    ///
102    /// ```rust,ignore
103    /// use busbar_sf_client::security::soql;
104    ///
105    /// // CORRECT - properly escaped:
106    /// let safe_name = soql::escape_string(user_input);
107    /// let query = format!("SELECT Id FROM ApexClass WHERE Name = '{}'", safe_name);
108    /// ```
109    #[instrument(skip(self))]
110    pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
111        self.client.tooling_query(soql).await.map_err(Into::into)
112    }
113
114    /// Execute a SOQL query and return all results (automatic pagination).
115    ///
116    /// # Security
117    ///
118    /// **IMPORTANT**: Escape user-provided values with `busbar_sf_client::security::soql::escape_string()`
119    /// to prevent SOQL injection attacks. See `query()` for examples.
120    #[instrument(skip(self))]
121    pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
122        self.client
123            .tooling_query_all(soql)
124            .await
125            .map_err(Into::into)
126    }
127
128    // =========================================================================
129    // Execute Anonymous
130    // =========================================================================
131
132    /// Execute anonymous Apex code.
133    ///
134    /// # Example
135    ///
136    /// ```rust,ignore
137    /// let result = client.execute_anonymous("System.debug('Hello World');").await?;
138    /// if result.success {
139    ///     println!("Execution successful");
140    /// } else if let Some(err) = result.compile_problem {
141    ///     println!("Compilation error: {}", err);
142    /// }
143    /// ```
144    #[instrument(skip(self))]
145    pub async fn execute_anonymous(&self, apex_code: &str) -> Result<ExecuteAnonymousResult> {
146        let encoded = urlencoding::encode(apex_code);
147        let url = format!(
148            "{}/services/data/v{}/tooling/executeAnonymous/?anonymousBody={}",
149            self.client.instance_url(),
150            self.client.api_version(),
151            encoded
152        );
153
154        let result: ExecuteAnonymousResult = self.client.get_json(&url).await?;
155
156        // Check for compilation or execution errors
157        if !result.compiled {
158            if let Some(ref problem) = result.compile_problem {
159                return Err(Error::new(ErrorKind::ApexCompilation(problem.clone())));
160            }
161        }
162
163        if !result.success {
164            if let Some(ref message) = result.exception_message {
165                return Err(Error::new(ErrorKind::ApexExecution(message.clone())));
166            }
167        }
168
169        Ok(result)
170    }
171
172    // =========================================================================
173    // Apex Class Operations
174    // =========================================================================
175
176    /// Get all Apex classes in the org.
177    #[instrument(skip(self))]
178    pub async fn get_apex_classes(&self) -> Result<Vec<ApexClass>> {
179        self.query_all("SELECT Id, Name, Body, Status, IsValid, ApiVersion, NamespacePrefix, CreatedDate, LastModifiedDate FROM ApexClass")
180            .await
181    }
182
183    /// Get an Apex class by name.
184    #[instrument(skip(self))]
185    pub async fn get_apex_class_by_name(&self, name: &str) -> Result<Option<ApexClass>> {
186        let safe_name = soql::escape_string(name);
187        let soql = format!(
188            "SELECT Id, Name, Body, Status, IsValid, ApiVersion, NamespacePrefix, CreatedDate, LastModifiedDate FROM ApexClass WHERE Name = '{}'",
189            safe_name
190        );
191        let mut classes: Vec<ApexClass> = self.query_all(&soql).await?;
192        Ok(classes.pop())
193    }
194
195    /// Get an Apex class by ID.
196    #[instrument(skip(self))]
197    pub async fn get_apex_class(&self, id: &str) -> Result<ApexClass> {
198        if !url_security::is_valid_salesforce_id(id) {
199            return Err(Error::new(ErrorKind::Salesforce {
200                error_code: "INVALID_ID".to_string(),
201                message: "Invalid Salesforce ID format".to_string(),
202            }));
203        }
204        let path = format!("sobjects/ApexClass/{}", id);
205        self.client.tooling_get(&path).await.map_err(Into::into)
206    }
207
208    // =========================================================================
209    // Apex Trigger Operations
210    // =========================================================================
211
212    /// Get all Apex triggers in the org.
213    #[instrument(skip(self))]
214    pub async fn get_apex_triggers(&self) -> Result<Vec<ApexTrigger>> {
215        self.query_all(
216            "SELECT Id, Name, Body, Status, IsValid, ApiVersion, TableEnumOrId FROM ApexTrigger",
217        )
218        .await
219    }
220
221    /// Get an Apex trigger by name.
222    #[instrument(skip(self))]
223    pub async fn get_apex_trigger_by_name(&self, name: &str) -> Result<Option<ApexTrigger>> {
224        let safe_name = soql::escape_string(name);
225        let soql = format!(
226            "SELECT Id, Name, Body, Status, IsValid, ApiVersion, TableEnumOrId FROM ApexTrigger WHERE Name = '{}'",
227            safe_name
228        );
229        let mut triggers: Vec<ApexTrigger> = self.query_all(&soql).await?;
230        Ok(triggers.pop())
231    }
232
233    // =========================================================================
234    // Debug Log Operations
235    // =========================================================================
236
237    /// Get recent Apex logs.
238    ///
239    /// # Arguments
240    /// * `limit` - Maximum number of logs to return (defaults to 20)
241    #[instrument(skip(self))]
242    pub async fn get_apex_logs(&self, limit: Option<u32>) -> Result<Vec<ApexLog>> {
243        let limit = limit.unwrap_or(20);
244        let soql = format!(
245            "SELECT Id, LogUserId, LogUser.Name, LogLength, LastModifiedDate, StartTime, Status, Operation, Request, Application, DurationMilliseconds, Location FROM ApexLog ORDER BY LastModifiedDate DESC LIMIT {}",
246            limit
247        );
248        self.query_all(&soql).await
249    }
250
251    /// Get the body of a specific Apex log.
252    #[instrument(skip(self))]
253    pub async fn get_apex_log_body(&self, log_id: &str) -> Result<String> {
254        if !url_security::is_valid_salesforce_id(log_id) {
255            return Err(Error::new(ErrorKind::Salesforce {
256                error_code: "INVALID_ID".to_string(),
257                message: "Invalid Salesforce ID format".to_string(),
258            }));
259        }
260        let url = format!(
261            "{}/services/data/v{}/tooling/sobjects/ApexLog/{}/Body",
262            self.client.instance_url(),
263            self.client.api_version(),
264            log_id
265        );
266
267        let request = self.client.get(&url);
268        let response = self.client.execute(request).await?;
269        response.text().await.map_err(Into::into)
270    }
271
272    /// Delete an Apex log.
273    #[instrument(skip(self))]
274    pub async fn delete_apex_log(&self, log_id: &str) -> Result<()> {
275        if !url_security::is_valid_salesforce_id(log_id) {
276            return Err(Error::new(ErrorKind::Salesforce {
277                error_code: "INVALID_ID".to_string(),
278                message: "Invalid Salesforce ID format".to_string(),
279            }));
280        }
281        let url = format!(
282            "{}/services/data/v{}/tooling/sobjects/ApexLog/{}",
283            self.client.instance_url(),
284            self.client.api_version(),
285            log_id
286        );
287
288        let request = self.client.delete(&url);
289        let response = self.client.execute(request).await?;
290
291        if response.status() == 204 || response.is_success() {
292            Ok(())
293        } else {
294            Err(Error::new(ErrorKind::Salesforce {
295                error_code: "DELETE_FAILED".to_string(),
296                message: format!("Failed to delete log: status {}", response.status()),
297            }))
298        }
299    }
300
301    /// Delete all Apex logs for the current user.
302    #[instrument(skip(self))]
303    pub async fn delete_all_apex_logs(&self) -> Result<u32> {
304        let logs = self.get_apex_logs(Some(200)).await?;
305        let count = logs.len() as u32;
306
307        for log in logs {
308            self.delete_apex_log(&log.id).await?;
309        }
310
311        Ok(count)
312    }
313
314    // =========================================================================
315    // Code Coverage Operations
316    // =========================================================================
317
318    /// Get code coverage for all Apex classes and triggers.
319    #[instrument(skip(self))]
320    pub async fn get_code_coverage(&self) -> Result<Vec<ApexCodeCoverageAggregate>> {
321        self.query_all(
322            "SELECT Id, ApexClassOrTriggerId, ApexClassOrTrigger.Name, NumLinesCovered, NumLinesUncovered, Coverage FROM ApexCodeCoverageAggregate"
323        ).await
324    }
325
326    /// Get overall org-wide code coverage percentage.
327    #[instrument(skip(self))]
328    pub async fn get_org_wide_coverage(&self) -> Result<f64> {
329        let coverage = self.get_code_coverage().await?;
330
331        let mut total_covered = 0i64;
332        let mut total_uncovered = 0i64;
333
334        for item in coverage {
335            total_covered += item.num_lines_covered as i64;
336            total_uncovered += item.num_lines_uncovered as i64;
337        }
338
339        let total_lines = total_covered + total_uncovered;
340        if total_lines == 0 {
341            return Ok(0.0);
342        }
343
344        Ok((total_covered as f64 / total_lines as f64) * 100.0)
345    }
346
347    // =========================================================================
348    // Trace Flag Operations
349    // =========================================================================
350
351    /// Get all active trace flags.
352    #[instrument(skip(self))]
353    pub async fn get_trace_flags(&self) -> Result<Vec<TraceFlag>> {
354        self.query_all(
355            "SELECT Id, TracedEntityId, LogType, DebugLevelId, StartDate, ExpirationDate FROM TraceFlag"
356        ).await
357    }
358
359    /// Get all debug levels.
360    #[instrument(skip(self))]
361    pub async fn get_debug_levels(&self) -> Result<Vec<DebugLevel>> {
362        self.query_all(
363            "SELECT Id, DeveloperName, MasterLabel, ApexCode, ApexProfiling, Callout, Database, System, Validation, Visualforce, Workflow FROM DebugLevel"
364        ).await
365    }
366
367    // =========================================================================
368    // Generic SObject Operations (Tooling)
369    // =========================================================================
370
371    /// Get a Tooling API SObject by ID.
372    #[instrument(skip(self))]
373    pub async fn get<T: DeserializeOwned>(&self, sobject: &str, id: &str) -> Result<T> {
374        if !soql::is_safe_sobject_name(sobject) {
375            return Err(Error::new(ErrorKind::Salesforce {
376                error_code: "INVALID_SOBJECT".to_string(),
377                message: "Invalid SObject name".to_string(),
378            }));
379        }
380        if !url_security::is_valid_salesforce_id(id) {
381            return Err(Error::new(ErrorKind::Salesforce {
382                error_code: "INVALID_ID".to_string(),
383                message: "Invalid Salesforce ID format".to_string(),
384            }));
385        }
386        let path = format!("sobjects/{}/{}", sobject, id);
387        self.client.tooling_get(&path).await.map_err(Into::into)
388    }
389
390    /// Create a Tooling API SObject.
391    #[instrument(skip(self, record))]
392    pub async fn create<T: serde::Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
393        if !soql::is_safe_sobject_name(sobject) {
394            return Err(Error::new(ErrorKind::Salesforce {
395                error_code: "INVALID_SOBJECT".to_string(),
396                message: "Invalid SObject name".to_string(),
397            }));
398        }
399        let path = format!("sobjects/{}", sobject);
400        let result: CreateResponse = self.client.tooling_post(&path, record).await?;
401
402        if result.success {
403            Ok(result.id)
404        } else {
405            Err(Error::new(ErrorKind::Salesforce {
406                error_code: "CREATE_FAILED".to_string(),
407                message: result
408                    .errors
409                    .into_iter()
410                    .map(|e| e.message)
411                    .collect::<Vec<_>>()
412                    .join("; "),
413            }))
414        }
415    }
416
417    /// Delete a Tooling API SObject.
418    #[instrument(skip(self))]
419    pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
420        if !soql::is_safe_sobject_name(sobject) {
421            return Err(Error::new(ErrorKind::Salesforce {
422                error_code: "INVALID_SOBJECT".to_string(),
423                message: "Invalid SObject name".to_string(),
424            }));
425        }
426        if !url_security::is_valid_salesforce_id(id) {
427            return Err(Error::new(ErrorKind::Salesforce {
428                error_code: "INVALID_ID".to_string(),
429                message: "Invalid Salesforce ID format".to_string(),
430            }));
431        }
432        let url = format!(
433            "{}/services/data/v{}/tooling/sobjects/{}/{}",
434            self.client.instance_url(),
435            self.client.api_version(),
436            sobject,
437            id
438        );
439
440        let request = self.client.delete(&url);
441        let response = self.client.execute(request).await?;
442
443        if response.status() == 204 || response.is_success() {
444            Ok(())
445        } else {
446            Err(Error::new(ErrorKind::Salesforce {
447                error_code: "DELETE_FAILED".to_string(),
448                message: format!("Failed to delete {}: status {}", sobject, response.status()),
449            }))
450        }
451    }
452}
453
/// Response from create operations.
///
/// Deserialized from the body returned by the Tooling `sobjects/{name}` POST
/// issued in `ToolingClient::create`.
#[derive(Debug, Clone, serde::Deserialize)]
struct CreateResponse {
    /// ID of the record; returned to the caller when `success` is true.
    id: String,
    /// Whether the create succeeded.
    success: bool,
    /// Error entries reported with the response; empty when the key is absent.
    #[serde(default)]
    errors: Vec<CreateError>,
}
462
463#[derive(Debug, Clone, serde::Deserialize)]
464struct CreateError {
465    message: String,
466    #[serde(rename = "statusCode")]
467    #[allow(dead_code)]
468    status_code: String,
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Instance URL shared by the construction tests.
    const INSTANCE: &str = "https://na1.salesforce.com";

    /// A freshly built client reports the URL it was given and API version "62.0".
    #[test]
    fn test_client_creation() {
        let tooling = ToolingClient::new(INSTANCE, "token123").unwrap();

        assert_eq!(tooling.instance_url(), INSTANCE);
        assert_eq!(tooling.api_version(), "62.0");
    }

    /// `with_api_version` replaces the version reported by the client.
    #[test]
    fn test_api_version_override() {
        let built = ToolingClient::new(INSTANCE, "token").unwrap();
        let tooling = built.with_api_version("60.0");

        assert_eq!(tooling.api_version(), "60.0");
    }
}
491}