1use serde::{de::DeserializeOwned, Serialize};
12use tracing::instrument;
13
14use crate::client::SfHttpClient;
15use crate::config::ClientConfig;
16use crate::error::{Error, ErrorKind, Result};
17use crate::request::RequestBuilder;
18use crate::DEFAULT_API_VERSION;
19
20#[derive(Clone)]
49pub struct SalesforceClient {
50 http: SfHttpClient,
51 instance_url: String,
52 access_token: String,
53 api_version: String,
54}
55
56impl std::fmt::Debug for SalesforceClient {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("SalesforceClient")
59 .field("instance_url", &self.instance_url)
60 .field("access_token", &"[REDACTED]")
61 .field("api_version", &self.api_version)
62 .finish_non_exhaustive()
63 }
64}
65
66impl SalesforceClient {
67 pub fn new(instance_url: impl Into<String>, access_token: impl Into<String>) -> Result<Self> {
69 Self::with_config(instance_url, access_token, ClientConfig::default())
70 }
71
72 pub fn with_config(
74 instance_url: impl Into<String>,
75 access_token: impl Into<String>,
76 config: ClientConfig,
77 ) -> Result<Self> {
78 let http = SfHttpClient::new(config)?;
79 Ok(Self {
80 http,
81 instance_url: instance_url.into().trim_end_matches('/').to_string(),
82 access_token: access_token.into(),
83 api_version: DEFAULT_API_VERSION.to_string(),
84 })
85 }
86
87 pub fn with_api_version(mut self, version: impl Into<String>) -> Self {
89 self.api_version = version.into();
90 self
91 }
92
93 pub fn instance_url(&self) -> &str {
95 &self.instance_url
96 }
97
98 pub fn access_token(&self) -> &str {
100 &self.access_token
101 }
102
103 pub fn api_version(&self) -> &str {
105 &self.api_version
106 }
107
108 pub fn url(&self, path: &str) -> String {
113 if path.starts_with("http://") || path.starts_with("https://") {
114 path.to_string()
115 } else if path.starts_with('/') {
116 format!("{}{}", self.instance_url, path)
117 } else {
118 format!("{}/{}", self.instance_url, path)
119 }
120 }
121
122 pub fn rest_url(&self, path: &str) -> String {
126 let path = path.trim_start_matches('/');
127 format!(
128 "{}/services/data/v{}/{}",
129 self.instance_url, self.api_version, path
130 )
131 }
132
133 pub fn tooling_url(&self, path: &str) -> String {
137 let path = path.trim_start_matches('/');
138 format!(
139 "{}/services/data/v{}/tooling/{}",
140 self.instance_url, self.api_version, path
141 )
142 }
143
144 pub fn metadata_url(&self) -> String {
146 format!("{}/services/Soap/m/{}", self.instance_url, self.api_version)
147 }
148
149 pub fn bulk_url(&self, path: &str) -> String {
151 let path = path.trim_start_matches('/');
152 format!(
153 "{}/services/data/v{}/jobs/{}",
154 self.instance_url, self.api_version, path
155 )
156 }
157
158 pub fn get(&self, url: &str) -> RequestBuilder {
164 self.http.get(url).bearer_auth(&self.access_token)
165 }
166
167 pub fn post(&self, url: &str) -> RequestBuilder {
169 self.http.post(url).bearer_auth(&self.access_token)
170 }
171
172 pub fn patch(&self, url: &str) -> RequestBuilder {
174 self.http.patch(url).bearer_auth(&self.access_token)
175 }
176
177 pub fn put(&self, url: &str) -> RequestBuilder {
179 self.http.put(url).bearer_auth(&self.access_token)
180 }
181
182 pub fn delete(&self, url: &str) -> RequestBuilder {
184 self.http.delete(url).bearer_auth(&self.access_token)
185 }
186
187 pub async fn execute(&self, request: RequestBuilder) -> Result<crate::Response> {
189 self.http.execute(request).await
190 }
191
192 #[instrument(skip(self), fields(url = %url))]
198 pub async fn get_json<T: DeserializeOwned>(&self, url: &str) -> Result<T> {
199 let full_url = self.url(url);
200 let request = self.get(&full_url);
201 let response = self.http.execute(request).await?;
202 response.json().await
203 }
204
205 pub async fn rest_get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
207 self.get_json(&self.rest_url(path)).await
208 }
209
210 pub async fn tooling_get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
212 self.get_json(&self.tooling_url(path)).await
213 }
214
215 #[instrument(skip(self, body), fields(url = %url))]
217 pub async fn post_json<T: DeserializeOwned, B: Serialize>(
218 &self,
219 url: &str,
220 body: &B,
221 ) -> Result<T> {
222 let full_url = self.url(url);
223 let request = self.post(&full_url).json(body)?;
224 let response = self.http.execute(request).await?;
225 response.json().await
226 }
227
228 pub async fn rest_post<T: DeserializeOwned, B: Serialize>(
230 &self,
231 path: &str,
232 body: &B,
233 ) -> Result<T> {
234 self.post_json(&self.rest_url(path), body).await
235 }
236
237 pub async fn tooling_post<T: DeserializeOwned, B: Serialize>(
239 &self,
240 path: &str,
241 body: &B,
242 ) -> Result<T> {
243 self.post_json(&self.tooling_url(path), body).await
244 }
245
246 #[instrument(skip(self, body), fields(url = %url))]
248 pub async fn patch_json<B: Serialize>(&self, url: &str, body: &B) -> Result<()> {
249 let full_url = self.url(url);
250 let request = self.patch(&full_url).json(body)?;
251 let response = self.http.execute(request).await?;
252
253 if response.status() == 204 || response.is_success() {
255 Ok(())
256 } else {
257 Err(Error::new(ErrorKind::Http {
258 status: response.status(),
259 message: "PATCH request failed".to_string(),
260 }))
261 }
262 }
263
264 pub async fn rest_patch<B: Serialize>(&self, path: &str, body: &B) -> Result<()> {
266 self.patch_json(&self.rest_url(path), body).await
267 }
268
269 #[instrument(skip(self), fields(url = %url))]
271 pub async fn delete_request(&self, url: &str) -> Result<()> {
272 let full_url = self.url(url);
273 let request = self.delete(&full_url);
274 let response = self.http.execute(request).await?;
275
276 if response.status() == 204 || response.is_success() {
278 Ok(())
279 } else {
280 Err(Error::new(ErrorKind::Http {
281 status: response.status(),
282 message: "DELETE request failed".to_string(),
283 }))
284 }
285 }
286
287 pub async fn rest_delete(&self, path: &str) -> Result<()> {
289 self.delete_request(&self.rest_url(path)).await
290 }
291
292 pub async fn get_json_if_changed<T: DeserializeOwned>(
299 &self,
300 url: &str,
301 etag: &str,
302 ) -> Result<Option<(T, Option<String>)>> {
303 let full_url = self.url(url);
304 let request = self.get(&full_url).if_none_match(etag);
305 let response = self.http.execute(request).await?;
306
307 if response.is_not_modified() {
308 return Ok(None);
309 }
310
311 let new_etag = response.etag().map(|s| s.to_string());
312 let data: T = response.json().await?;
313 Ok(Some((data, new_etag)))
314 }
315
316 pub async fn get_json_if_modified<T: DeserializeOwned>(
319 &self,
320 url: &str,
321 since: &str,
322 ) -> Result<Option<(T, Option<String>)>> {
323 let full_url = self.url(url);
324 let request = self.get(&full_url).if_modified_since(since);
325 let response = self.http.execute(request).await?;
326
327 if response.is_not_modified() {
328 return Ok(None);
329 }
330
331 let last_modified = response.last_modified().map(|s| s.to_string());
332 let data: T = response.json().await?;
333 Ok(Some((data, last_modified)))
334 }
335
336 pub async fn query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
342 let encoded = urlencoding::encode(soql);
343 let url = format!(
344 "{}/services/data/v{}/query?q={}",
345 self.instance_url, self.api_version, encoded
346 );
347 self.get_json(&url).await
348 }
349
350 pub async fn tooling_query<T: DeserializeOwned>(&self, soql: &str) -> Result<QueryResult<T>> {
352 let encoded = urlencoding::encode(soql);
353 let url = format!(
354 "{}/services/data/v{}/tooling/query?q={}",
355 self.instance_url, self.api_version, encoded
356 );
357 self.get_json(&url).await
358 }
359
360 pub async fn query_all<T: DeserializeOwned + Clone>(&self, soql: &str) -> Result<Vec<T>> {
362 let mut all_records = Vec::new();
363 let mut result: QueryResult<T> = self.query(soql).await?;
364
365 all_records.extend(result.records);
366
367 while let Some(ref next_url) = result.next_records_url {
368 result = self.get_json(next_url).await?;
369 all_records.extend(result.records);
370 }
371
372 Ok(all_records)
373 }
374
375 pub async fn tooling_query_all<T: DeserializeOwned + Clone>(
377 &self,
378 soql: &str,
379 ) -> Result<Vec<T>> {
380 let mut all_records = Vec::new();
381 let mut result: QueryResult<T> = self.tooling_query(soql).await?;
382
383 all_records.extend(result.records);
384
385 while let Some(ref next_url) = result.next_records_url {
386 result = self.get_json(next_url).await?;
387 all_records.extend(result.records);
388 }
389
390 Ok(all_records)
391 }
392}
393
394#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
396pub struct QueryResult<T> {
397 #[serde(rename = "totalSize")]
399 pub total_size: u64,
400
401 pub done: bool,
403
404 #[serde(rename = "nextRecordsUrl")]
406 pub next_records_url: Option<String>,
407
408 pub records: Vec<T>,
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    const INSTANCE: &str = "https://na1.salesforce.com";

    /// Builds a client with the default config; a failure here is a test-setup bug.
    fn default_client() -> SalesforceClient {
        SalesforceClient::new(INSTANCE, "token123").expect("client with default config should build")
    }

    /// Expected versioned data prefix, derived from `DEFAULT_API_VERSION` so the
    /// assertions keep passing when the default version is bumped.
    fn default_data_prefix() -> String {
        format!("{}/services/data/v{}", INSTANCE, DEFAULT_API_VERSION)
    }

    #[test]
    fn test_url_building() {
        let client = default_client();
        let prefix = default_data_prefix();

        assert_eq!(
            client.url("/services/oauth2/userinfo"),
            "https://na1.salesforce.com/services/oauth2/userinfo"
        );
        assert_eq!(
            client.url("services/oauth2/userinfo"),
            "https://na1.salesforce.com/services/oauth2/userinfo"
        );
        assert_eq!(client.url("https://other.com/path"), "https://other.com/path");

        assert_eq!(
            client.rest_url("sobjects/Account"),
            format!("{}/sobjects/Account", prefix)
        );
        // A leading slash on the path must not produce `//`.
        assert_eq!(
            client.rest_url("/sobjects/Account"),
            format!("{}/sobjects/Account", prefix)
        );
        assert_eq!(
            client.tooling_url("sobjects/ApexClass"),
            format!("{}/tooling/sobjects/ApexClass", prefix)
        );
        assert_eq!(client.bulk_url("ingest"), format!("{}/jobs/ingest", prefix));
    }

    #[test]
    fn test_metadata_url() {
        let client = default_client();
        assert_eq!(
            client.metadata_url(),
            format!("{}/services/Soap/m/{}", INSTANCE, DEFAULT_API_VERSION)
        );
    }

    #[test]
    fn test_api_version() {
        let client = default_client().with_api_version("60.0");

        assert_eq!(client.api_version(), "60.0");
        assert_eq!(
            client.rest_url("limits"),
            "https://na1.salesforce.com/services/data/v60.0/limits"
        );
    }

    #[test]
    fn test_trailing_slash_handling() {
        for raw in ["https://na1.salesforce.com/", "https://na1.salesforce.com///"] {
            let client = SalesforceClient::new(raw, "token").expect("client should build");

            assert_eq!(client.instance_url(), INSTANCE);
            assert_eq!(
                client.rest_url("limits"),
                format!("{}/limits", default_data_prefix())
            );
        }
    }

    #[test]
    fn test_debug_redacts_token() {
        let rendered = format!("{:?}", default_client());

        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains("token123"));
    }
}