1use serde::de::DeserializeOwned;
7use tracing::instrument;
8
9use busbar_sf_client::security::{soql, url as url_security};
10use busbar_sf_client::{ClientConfig, QueryResult, SalesforceClient};
11
12use crate::error::{Error, ErrorKind, Result};
13use crate::types::*;
14
15#[derive(Debug, Clone)]
42pub struct ToolingClient {
43 client: SalesforceClient,
44}
45
46impl ToolingClient {
47 pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
49 let client = SalesforceClient::new(instance_url, access_token)?;
50 Ok(Self { client })
51 }
52
53 pub fn with_config(
55 instance_url: impl Into<String>,
56 access_token: impl Into<String>,
57 config: ClientConfig,
58 ) -> Result<Self> {
59 let client = SalesforceClient::with_config(instance_url, access_token, config)?;
60 Ok(Self { client })
61 }
62
63 pub fn from_client(client: SalesforceClient) -> Self {
65 Self { client }
66 }
67
68 pub fn inner(&self) -> &SalesforceClient {
70 &self.client
71 }
72
73 pub fn instance_url(&self) -> &str {
75 self.client.instance_url()
76 }
77
78 pub fn api_version(&self) -> &str {
80 self.client.api_version()
81 }
82
83 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
85 self.client = self.client.with_api_version(version);
86 self
87 }
88
89 #[instrument(skip(self))]
110 pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
111 self.client.tooling_query(soql).await.map_err(Into::into)
112 }
113
114 #[instrument(skip(self))]
121 pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
122 self.client
123 .tooling_query_all(soql)
124 .await
125 .map_err(Into::into)
126 }
127
128 #[instrument(skip(self))]
145 pub async fn execute_anonymous(&self, apex_code: &str) -> Result<ExecuteAnonymousResult> {
146 let encoded = urlencoding::encode(apex_code);
147 let url = format!(
148 "{}/services/data/v{}/tooling/executeAnonymous/?anonymousBody={}",
149 self.client.instance_url(),
150 self.client.api_version(),
151 encoded
152 );
153
154 let result: ExecuteAnonymousResult = self.client.get_json(&url).await?;
155
156 if !result.compiled {
158 if let Some(ref problem) = result.compile_problem {
159 return Err(Error::new(ErrorKind::ApexCompilation(problem.clone())));
160 }
161 }
162
163 if !result.success {
164 if let Some(ref message) = result.exception_message {
165 return Err(Error::new(ErrorKind::ApexExecution(message.clone())));
166 }
167 }
168
169 Ok(result)
170 }
171
172 #[instrument(skip(self))]
178 pub async fn get_apex_classes(&self) -> Result<Vec<ApexClass>> {
179 self.query_all("SELECT Id, Name, Body, Status, IsValid, ApiVersion, NamespacePrefix, CreatedDate, LastModifiedDate FROM ApexClass")
180 .await
181 }
182
183 #[instrument(skip(self))]
185 pub async fn get_apex_class_by_name(&self, name: &str) -> Result<Option<ApexClass>> {
186 let safe_name = soql::escape_string(name);
187 let soql = format!(
188 "SELECT Id, Name, Body, Status, IsValid, ApiVersion, NamespacePrefix, CreatedDate, LastModifiedDate FROM ApexClass WHERE Name = '{}'",
189 safe_name
190 );
191 let mut classes: Vec<ApexClass> = self.query_all(&soql).await?;
192 Ok(classes.pop())
193 }
194
195 #[instrument(skip(self))]
197 pub async fn get_apex_class(&self, id: &str) -> Result<ApexClass> {
198 if !url_security::is_valid_salesforce_id(id) {
199 return Err(Error::new(ErrorKind::Salesforce {
200 error_code: "INVALID_ID".to_string(),
201 message: "Invalid Salesforce ID format".to_string(),
202 }));
203 }
204 let path = format!("sobjects/ApexClass/{}", id);
205 self.client.tooling_get(&path).await.map_err(Into::into)
206 }
207
208 #[instrument(skip(self))]
214 pub async fn get_apex_triggers(&self) -> Result<Vec<ApexTrigger>> {
215 self.query_all(
216 "SELECT Id, Name, Body, Status, IsValid, ApiVersion, TableEnumOrId FROM ApexTrigger",
217 )
218 .await
219 }
220
221 #[instrument(skip(self))]
223 pub async fn get_apex_trigger_by_name(&self, name: &str) -> Result<Option<ApexTrigger>> {
224 let safe_name = soql::escape_string(name);
225 let soql = format!(
226 "SELECT Id, Name, Body, Status, IsValid, ApiVersion, TableEnumOrId FROM ApexTrigger WHERE Name = '{}'",
227 safe_name
228 );
229 let mut triggers: Vec<ApexTrigger> = self.query_all(&soql).await?;
230 Ok(triggers.pop())
231 }
232
233 #[instrument(skip(self))]
242 pub async fn get_apex_logs(&self, limit: Option<u32>) -> Result<Vec<ApexLog>> {
243 let limit = limit.unwrap_or(20);
244 let soql = format!(
245 "SELECT Id, LogUserId, LogUser.Name, LogLength, LastModifiedDate, StartTime, Status, Operation, Request, Application, DurationMilliseconds, Location FROM ApexLog ORDER BY LastModifiedDate DESC LIMIT {}",
246 limit
247 );
248 self.query_all(&soql).await
249 }
250
251 #[instrument(skip(self))]
253 pub async fn get_apex_log_body(&self, log_id: &str) -> Result<String> {
254 if !url_security::is_valid_salesforce_id(log_id) {
255 return Err(Error::new(ErrorKind::Salesforce {
256 error_code: "INVALID_ID".to_string(),
257 message: "Invalid Salesforce ID format".to_string(),
258 }));
259 }
260 let url = format!(
261 "{}/services/data/v{}/tooling/sobjects/ApexLog/{}/Body",
262 self.client.instance_url(),
263 self.client.api_version(),
264 log_id
265 );
266
267 let request = self.client.get(&url);
268 let response = self.client.execute(request).await?;
269 response.text().await.map_err(Into::into)
270 }
271
272 #[instrument(skip(self))]
274 pub async fn delete_apex_log(&self, log_id: &str) -> Result<()> {
275 if !url_security::is_valid_salesforce_id(log_id) {
276 return Err(Error::new(ErrorKind::Salesforce {
277 error_code: "INVALID_ID".to_string(),
278 message: "Invalid Salesforce ID format".to_string(),
279 }));
280 }
281 let url = format!(
282 "{}/services/data/v{}/tooling/sobjects/ApexLog/{}",
283 self.client.instance_url(),
284 self.client.api_version(),
285 log_id
286 );
287
288 let request = self.client.delete(&url);
289 let response = self.client.execute(request).await?;
290
291 if response.status() == 204 || response.is_success() {
292 Ok(())
293 } else {
294 Err(Error::new(ErrorKind::Salesforce {
295 error_code: "DELETE_FAILED".to_string(),
296 message: format!("Failed to delete log: status {}", response.status()),
297 }))
298 }
299 }
300
301 #[instrument(skip(self))]
303 pub async fn delete_all_apex_logs(&self) -> Result<u32> {
304 let logs = self.get_apex_logs(Some(200)).await?;
305 let count = logs.len() as u32;
306
307 for log in logs {
308 self.delete_apex_log(&log.id).await?;
309 }
310
311 Ok(count)
312 }
313
314 #[instrument(skip(self))]
320 pub async fn get_code_coverage(&self) -> Result<Vec<ApexCodeCoverageAggregate>> {
321 self.query_all(
322 "SELECT Id, ApexClassOrTriggerId, ApexClassOrTrigger.Name, NumLinesCovered, NumLinesUncovered, Coverage FROM ApexCodeCoverageAggregate"
323 ).await
324 }
325
326 #[instrument(skip(self))]
328 pub async fn get_org_wide_coverage(&self) -> Result<f64> {
329 let coverage = self.get_code_coverage().await?;
330
331 let mut total_covered = 0i64;
332 let mut total_uncovered = 0i64;
333
334 for item in coverage {
335 total_covered += item.num_lines_covered as i64;
336 total_uncovered += item.num_lines_uncovered as i64;
337 }
338
339 let total_lines = total_covered + total_uncovered;
340 if total_lines == 0 {
341 return Ok(0.0);
342 }
343
344 Ok((total_covered as f64 / total_lines as f64) * 100.0)
345 }
346
347 #[instrument(skip(self))]
353 pub async fn get_trace_flags(&self) -> Result<Vec<TraceFlag>> {
354 self.query_all(
355 "SELECT Id, TracedEntityId, LogType, DebugLevelId, StartDate, ExpirationDate FROM TraceFlag"
356 ).await
357 }
358
359 #[instrument(skip(self))]
361 pub async fn get_debug_levels(&self) -> Result<Vec<DebugLevel>> {
362 self.query_all(
363 "SELECT Id, DeveloperName, MasterLabel, ApexCode, ApexProfiling, Callout, Database, System, Validation, Visualforce, Workflow FROM DebugLevel"
364 ).await
365 }
366
367 #[instrument(skip(self))]
373 pub async fn get<T: DeserializeOwned>(&self, sobject: &str, id: &str) -> Result<T> {
374 if !soql::is_safe_sobject_name(sobject) {
375 return Err(Error::new(ErrorKind::Salesforce {
376 error_code: "INVALID_SOBJECT".to_string(),
377 message: "Invalid SObject name".to_string(),
378 }));
379 }
380 if !url_security::is_valid_salesforce_id(id) {
381 return Err(Error::new(ErrorKind::Salesforce {
382 error_code: "INVALID_ID".to_string(),
383 message: "Invalid Salesforce ID format".to_string(),
384 }));
385 }
386 let path = format!("sobjects/{}/{}", sobject, id);
387 self.client.tooling_get(&path).await.map_err(Into::into)
388 }
389
390 #[instrument(skip(self, record))]
392 pub async fn create<T: serde::Serialize>(&self, sobject: &str, record: &T) -> Result<String> {
393 if !soql::is_safe_sobject_name(sobject) {
394 return Err(Error::new(ErrorKind::Salesforce {
395 error_code: "INVALID_SOBJECT".to_string(),
396 message: "Invalid SObject name".to_string(),
397 }));
398 }
399 let path = format!("sobjects/{}", sobject);
400 let result: CreateResponse = self.client.tooling_post(&path, record).await?;
401
402 if result.success {
403 Ok(result.id)
404 } else {
405 Err(Error::new(ErrorKind::Salesforce {
406 error_code: "CREATE_FAILED".to_string(),
407 message: result
408 .errors
409 .into_iter()
410 .map(|e| e.message)
411 .collect::<Vec<_>>()
412 .join("; "),
413 }))
414 }
415 }
416
417 #[instrument(skip(self))]
419 pub async fn delete(&self, sobject: &str, id: &str) -> Result<()> {
420 if !soql::is_safe_sobject_name(sobject) {
421 return Err(Error::new(ErrorKind::Salesforce {
422 error_code: "INVALID_SOBJECT".to_string(),
423 message: "Invalid SObject name".to_string(),
424 }));
425 }
426 if !url_security::is_valid_salesforce_id(id) {
427 return Err(Error::new(ErrorKind::Salesforce {
428 error_code: "INVALID_ID".to_string(),
429 message: "Invalid Salesforce ID format".to_string(),
430 }));
431 }
432 let url = format!(
433 "{}/services/data/v{}/tooling/sobjects/{}/{}",
434 self.client.instance_url(),
435 self.client.api_version(),
436 sobject,
437 id
438 );
439
440 let request = self.client.delete(&url);
441 let response = self.client.execute(request).await?;
442
443 if response.status() == 204 || response.is_success() {
444 Ok(())
445 } else {
446 Err(Error::new(ErrorKind::Salesforce {
447 error_code: "DELETE_FAILED".to_string(),
448 message: format!("Failed to delete {}: status {}", sobject, response.status()),
449 }))
450 }
451 }
452}
453
454#[derive(Debug, Clone, serde::Deserialize)]
456struct CreateResponse {
457 id: String,
458 success: bool,
459 #[serde(default)]
460 errors: Vec<CreateError>,
461}
462
463#[derive(Debug, Clone, serde::Deserialize)]
464struct CreateError {
465 message: String,
466 #[serde(rename = "statusCode")]
467 #[allow(dead_code)]
468 status_code: String,
469}
470
#[cfg(test)]
mod tests {
    use super::*;

    /// Instance URL shared by the construction tests.
    const INSTANCE: &str = "https://na1.salesforce.com";

    #[test]
    fn test_client_creation() {
        let tooling = ToolingClient::new(INSTANCE, "token123")
            .expect("constructing a client from a well-formed URL should succeed");

        assert_eq!(tooling.instance_url(), INSTANCE);
        assert_eq!(tooling.api_version(), "62.0");
    }

    #[test]
    fn test_api_version_override() {
        let tooling = ToolingClient::new(INSTANCE, "token")
            .expect("constructing a client from a well-formed URL should succeed")
            .with_api_version("60.0");

        assert_eq!(tooling.api_version(), "60.0");
    }
}