1use serde::{de::DeserializeOwned, Serialize};
12use tracing::instrument;
13
14use crate::client::SfHttpClient;
15use crate::config::ClientConfig;
16use crate::error::{Error, ErrorKind, Result};
17use crate::request::RequestBuilder;
18use crate::DEFAULT_API_VERSION;
19
20#[derive(Clone)]
49pub struct SalesforceClient {
50 http: SfHttpClient,
51 instance_url: String,
52 access_token: String,
53 api_version: String,
54}
55
56impl std::fmt::Debug for SalesforceClient {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("SalesforceClient")
59 .field("instance_url", &self.instance_url)
60 .field("access_token", &"[REDACTED]")
61 .field("api_version", &self.api_version)
62 .finish_non_exhaustive()
63 }
64}
65
66impl SalesforceClient {
67 pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
69 Self::with_config(instance_url, access_token, ClientConfig::default())
70 }
71
72 pub fn with_config(
74 instance_url: impl Into<String>,
75 access_token: impl Into<String>,
76 config: ClientConfig,
77 ) -> Result<Self> {
78 let http = SfHttpClient::new(config)?;
79 Ok(Self {
80 http,
81 instance_url: instance_url.into().trim_end_matches('/').to_string(),
82 access_token: access_token.into(),
83 api_version: DEFAULT_API_VERSION.to_string(),
84 })
85 }
86
87 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
89 self.api_version = version.into();
90 self
91 }
92
93 pub fn instance_url(&self) -> &str {
95 &self.instance_url
96 }
97
98 pub fn access_token(&self) -> &str {
100 &self.access_token
101 }
102
103 pub fn api_version(&self) -> &str {
105 &self.api_version
106 }
107
108 pub fn url(&self, path: &str) -> String {
113 if path.starts_with("http://") || path.starts_with("https://") {
114 path.to_string()
115 } else if path.starts_with('/') {
116 format!("{}{}", self.instance_url, path)
117 } else {
118 format!("{}/{}", self.instance_url, path)
119 }
120 }
121
122 pub fn rest_url(&self, path: &str) -> String {
126 let path = path.trim_start_matches('/');
127 format!(
128 "{}/services/data/v{}/{}",
129 self.instance_url, self.api_version, path
130 )
131 }
132
133 pub fn tooling_url(&self, path: &str) -> String {
137 let path = path.trim_start_matches('/');
138 format!(
139 "{}/services/data/v{}/tooling/{}",
140 self.instance_url, self.api_version, path
141 )
142 }
143
144 pub fn metadata_url(&self) -> String {
146 format!("{}/services/Soap/m/{}", self.instance_url, self.api_version)
147 }
148
149 pub fn bulk_url(&self, path: &str) -> String {
151 let path = path.trim_start_matches('/');
152 format!(
153 "{}/services/data/v{}/jobs/{}",
154 self.instance_url, self.api_version, path
155 )
156 }
157
158 pub fn get(&self, url: &str) -> RequestBuilder {
164 self.http.get(url).bearer_auth(&self.access_token)
165 }
166
167 pub fn post(&self, url: &str) -> RequestBuilder {
169 self.http.post(url).bearer_auth(&self.access_token)
170 }
171
172 pub fn patch(&self, url: &str) -> RequestBuilder {
174 self.http.patch(url).bearer_auth(&self.access_token)
175 }
176
177 pub fn put(&self, url: &str) -> RequestBuilder {
179 self.http.put(url).bearer_auth(&self.access_token)
180 }
181
182 pub fn delete(&self, url: &str) -> RequestBuilder {
184 self.http.delete(url).bearer_auth(&self.access_token)
185 }
186
187 pub async fn execute(&self, request: RequestBuilder) -> Result<crate::Response> {
189 self.http.execute(request).await
190 }
191
192 #[instrument(skip(self), fields(url = %url))]
198 pub async fn get_json<T: DeserializeOwned>(&self, url: &str) -> Result<T> {
199 let full_url = self.url(url);
200 let request = self.get(&full_url).header("Accept", "application/json");
201 let response = self.http.execute(request).await?;
202 response.json().await
203 }
204
205 pub async fn rest_get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
207 self.get_json(&self.rest_url(path)).await
208 }
209
210 pub async fn tooling_get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
212 self.get_json(&self.tooling_url(path)).await
213 }
214
215 #[instrument(skip(self, body), fields(url = %url))]
217 pub async fn post_json<T: DeserializeOwned, B: Serialize>(
218 &self,
219 url: &str,
220 body: &B,
221 ) -> Result<T> {
222 let full_url = self.url(url);
223 let request = self
224 .post(&full_url)
225 .header("Accept", "application/json")
226 .json(body)?;
227 let response = self.http.execute(request).await?;
228 response.json().await
229 }
230
231 pub async fn rest_post<T: DeserializeOwned, B: Serialize>(
233 &self,
234 path: &str,
235 body: &B,
236 ) -> Result<T> {
237 self.post_json(&self.rest_url(path), body).await
238 }
239
240 pub async fn tooling_post<T: DeserializeOwned, B: Serialize>(
242 &self,
243 path: &str,
244 body: &B,
245 ) -> Result<T> {
246 self.post_json(&self.tooling_url(path), body).await
247 }
248
249 #[instrument(skip(self, body), fields(url = %url))]
251 pub async fn patch_json<B: Serialize>(&self, url: &str, body: &B) -> Result<()> {
252 let full_url = self.url(url);
253 let request = self.patch(&full_url).json(body)?;
254 let response = self.http.execute(request).await?;
255
256 if response.status() == 204 || response.is_success() {
258 Ok(())
259 } else {
260 Err(Error::new(ErrorKind::Http {
261 status: response.status(),
262 message: "PATCH request failed".to_string(),
263 }))
264 }
265 }
266
267 pub async fn rest_patch<B: Serialize>(&self, path: &str, body: &B) -> Result<()> {
269 self.patch_json(&self.rest_url(path), body).await
270 }
271
272 #[instrument(skip(self), fields(url = %url))]
274 pub async fn delete_request(&self, url: &str) -> Result<()> {
275 let full_url = self.url(url);
276 let request = self.delete(&full_url);
277 let response = self.http.execute(request).await?;
278
279 if response.status() == 204 || response.is_success() {
281 Ok(())
282 } else {
283 Err(Error::new(ErrorKind::Http {
284 status: response.status(),
285 message: "DELETE request failed".to_string(),
286 }))
287 }
288 }
289
290 pub async fn rest_delete(&self, path: &str) -> Result<()> {
292 self.delete_request(&self.rest_url(path)).await
293 }
294
295 pub async fn get_json_if_changed<T: DeserializeOwned>(
302 &self,
303 url: &str,
304 etag: &str,
305 ) -> Result<Option<(T, Option<String>)>> {
306 let full_url = self.url(url);
307 let request = self.get(&full_url).if_none_match(etag);
308 let response = self.http.execute(request).await?;
309
310 if response.is_not_modified() {
311 return Ok(None);
312 }
313
314 let new_etag = response.etag().map(|s| s.to_string());
315 let data: T = response.json().await?;
316 Ok(Some((data, new_etag)))
317 }
318
319 pub async fn get_json_if_modified<T: DeserializeOwned>(
322 &self,
323 url: &str,
324 since: &str,
325 ) -> Result<Option<(T, Option<String>)>> {
326 let full_url = self.url(url);
327 let request = self.get(&full_url).if_modified_since(since);
328 let response = self.http.execute(request).await?;
329
330 if response.is_not_modified() {
331 return Ok(None);
332 }
333
334 let last_modified = response.last_modified().map(|s| s.to_string());
335 let data: T = response.json().await?;
336 Ok(Some((data, last_modified)))
337 }
338
339 pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
345 let encoded = urlencoding::encode(soql);
346 let url = format!(
347 "{}/services/data/v{}/query?q={}",
348 self.instance_url, self.api_version, encoded
349 );
350 self.get_json(&url).await
351 }
352
353 pub async fn tooling_query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
355 let encoded = urlencoding::encode(soql);
356 let url = format!(
357 "{}/services/data/v{}/tooling/query?q={}",
358 self.instance_url, self.api_version, encoded
359 );
360 self.get_json(&url).await
361 }
362
363 pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
365 let mut all_records = Vec::new();
366 let mut result: QueryResult<T> = self.query(soql).await?;
367
368 all_records.extend(result.records);
369
370 while let Some(ref next_url) = result.next_records_url {
371 result = self.get_json(next_url).await?;
372 all_records.extend(result.records);
373 }
374
375 Ok(all_records)
376 }
377
378 pub async fn tooling_query_all<T: DeserializeOwned + Clone>(
380 &self,
381 soql: &str,
382 ) -> Result<Vec<T>> {
383 let mut all_records = Vec::new();
384 let mut result: QueryResult<T> = self.tooling_query(soql).await?;
385
386 all_records.extend(result.records);
387
388 while let Some(ref next_url) = result.next_records_url {
389 result = self.get_json(next_url).await?;
390 all_records.extend(result.records);
391 }
392
393 Ok(all_records)
394 }
395}
396
397#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
399pub struct QueryResult<T> {
400 #[serde(rename = "totalSize")]
402 pub total_size: u64,
403
404 pub done: bool,
406
407 #[serde(rename = "nextRecordsUrl")]
409 pub next_records_url: Option<String>,
410
411 pub records: Vec<T>,
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DEFAULT_API_VERSION;

    /// Instance URL used throughout the tests, without a trailing slash.
    const INSTANCE: &str = "https://na1.salesforce.com";

    /// Builds a client for [`INSTANCE`] with the default configuration.
    fn default_client() -> SalesforceClient {
        SalesforceClient::new(INSTANCE, "token123")
            .expect("building a client with the default config should succeed")
    }

    /// Expected `/services/data/v{version}/{suffix}` URL on the test instance.
    ///
    /// The version is a parameter so the assertions keep passing when
    /// `DEFAULT_API_VERSION` is bumped.
    fn data_url(version: impl std::fmt::Display, suffix: &str) -> String {
        format!("{}/services/data/v{}/{}", INSTANCE, version, suffix)
    }

    #[test]
    fn test_url_building() {
        let client = default_client();

        assert_eq!(
            client.url("/services/oauth2/userinfo"),
            "https://na1.salesforce.com/services/oauth2/userinfo"
        );
        assert_eq!(
            client.url("services/oauth2/userinfo"),
            "https://na1.salesforce.com/services/oauth2/userinfo"
        );
        assert_eq!(client.url("https://other.com/path"), "https://other.com/path");
        assert_eq!(client.url("http://other.com/path"), "http://other.com/path");

        assert_eq!(
            client.rest_url("sobjects/Account"),
            data_url(DEFAULT_API_VERSION, "sobjects/Account")
        );
        // A leading slash on the relative path must not produce `//`.
        assert_eq!(
            client.rest_url("/sobjects/Account"),
            data_url(DEFAULT_API_VERSION, "sobjects/Account")
        );
        assert_eq!(
            client.tooling_url("sobjects/ApexClass"),
            data_url(DEFAULT_API_VERSION, "tooling/sobjects/ApexClass")
        );
        assert_eq!(
            client.bulk_url("ingest"),
            data_url(DEFAULT_API_VERSION, "jobs/ingest")
        );
        assert_eq!(
            client.metadata_url(),
            format!("{}/services/Soap/m/{}", INSTANCE, DEFAULT_API_VERSION)
        );
    }

    #[test]
    fn test_default_api_version() {
        assert_eq!(
            default_client().api_version(),
            DEFAULT_API_VERSION.to_string()
        );
    }

    #[test]
    fn test_api_version() {
        let client = default_client().with_api_version("60.0");

        assert_eq!(client.api_version(), "60.0");
        assert_eq!(client.rest_url("limits"), data_url("60.0", "limits"));
        assert_eq!(
            client.metadata_url(),
            format!("{}/services/Soap/m/60.0", INSTANCE)
        );
    }

    #[test]
    fn test_trailing_slash_handling() {
        for base in ["https://na1.salesforce.com/", "https://na1.salesforce.com//"] {
            let client = SalesforceClient::new(base, "token")
                .expect("building a client with the default config should succeed");

            assert_eq!(client.instance_url(), INSTANCE);
            assert_eq!(
                client.rest_url("limits"),
                data_url(DEFAULT_API_VERSION, "limits")
            );
        }
    }

    #[test]
    fn test_debug_redacts_access_token() {
        let client = default_client();
        let rendered = format!("{:?}", client);

        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("token123"));
        // The token itself is still available through the explicit accessor.
        assert_eq!(client.access_token(), "token123");
    }
}