use crate::api::rest_operation::RestOperation;
use crate::auth::Authenticator;
#[cfg(feature = "composite")]
use crate::client::ForceClient;
use crate::error::Result;
use serde::de::DeserializeOwned;
use serde_json::Value;
/// A single write operation produced by the `transform` closure of `QueryBatch::run`.
///
/// Each value becomes exactly one sub-request of a composite batch call.
#[derive(Clone, Debug)]
pub enum BatchOp {
    /// Insert a record: `(sobject_type, fields)`. Sent as a `post` sub-request.
    Create(String, Value),
    /// Modify a record: `(sobject_type, record_id, fields)`. Sent as a `patch` sub-request.
    Update(String, String, Value),
    /// Remove a record: `(sobject_type, record_id)`. Sent as a `delete` sub-request.
    Delete(String, String),
}
/// Running totals reported by `QueryBatch::run`.
///
/// All counters start at zero (`Default`).
#[derive(Clone, Copy, Debug, Default)]
pub struct BatchStats {
    /// How many query records were handed to the transform closure.
    pub records_processed: usize,
    /// How many composite sub-requests came back with a 2xx status code.
    pub ops_succeeded: usize,
    /// How many composite sub-requests came back with any other status code.
    pub ops_failed: usize,
}
/// Runs a query and turns each returned record into an optional [`BatchOp`], submitting the
/// resulting operations through the client's composite batch API.
///
/// Build with `QueryBatch::new`, optionally set `halt_on_error`, then call `run`.
#[cfg(feature = "composite")]
pub struct QueryBatch<'a, A: Authenticator> {
    /// Query text executed by `run` (follow-up pages are fetched via `query_more`).
    query: String,
    /// Client used for both the query and the composite batch requests.
    client: &'a ForceClient<A>,
    /// Forwarded as the `halt_on_error` setting of every composite batch request.
    halt_on_error: bool,
}
#[cfg(feature = "composite")]
impl<'a, A: Authenticator> QueryBatch<'a, A> {
    /// Maximum number of sub-requests placed in one composite batch call.
    ///
    /// The composite batch endpoint accepts at most 25 subrequests per request.
    const MAX_BATCH_SIZE: usize = 25;

    /// Creates a batch job that will run `query` through `client`.
    ///
    /// `halt_on_error` starts out as `false`. Nothing is sent until `run` is awaited.
    #[must_use]
    pub fn new(client: &'a ForceClient<A>, query: impl Into<String>) -> Self {
        Self {
            client,
            query: query.into(),
            halt_on_error: false,
        }
    }

    /// Sets the `halt_on_error` flag forwarded to every composite batch request.
    ///
    /// The flag applies per composite request only. A failed sub-request in one batch does not
    /// stop `run` from fetching further pages or sending later batches.
    #[must_use]
    pub fn halt_on_error(mut self, halt: bool) -> Self {
        self.halt_on_error = halt;
        self
    }

    /// Streams every record of the query, maps each one through `transform`, and submits the
    /// resulting operations in composite batches of at most 25 operations.
    ///
    /// The query pages are walked via `nextRecordsUrl` until the result reports it is done.
    /// A record for which `transform` returns `None` is still counted in
    /// [`BatchStats::records_processed`] but produces no operation. `transform` may be
    /// stateful (`FnMut`), e.g. to count or deduplicate records. Any `Fn` closure works too.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial query, any follow-up page fetch, or building/executing
    /// any composite batch fails. The statistics gathered up to that point are discarded.
    /// Sub-requests that come back with a non-2xx status are *not* errors. They are counted
    /// in [`BatchStats::ops_failed`].
    pub async fn run<T, F>(self, mut transform: F) -> Result<BatchStats>
    where
        T: DeserializeOwned,
        F: FnMut(T) -> Option<BatchOp>,
    {
        let mut stats = BatchStats::default();
        let mut pending: Vec<BatchOp> = Vec::with_capacity(Self::MAX_BATCH_SIZE);
        let mut page = self.client.rest().query::<T>(&self.query).await?;

        loop {
            // Take ownership of this page's records so they can be moved into `transform`.
            for record in std::mem::take(&mut page.records) {
                stats.records_processed += 1;
                if let Some(op) = transform(record) {
                    pending.push(op);
                    if pending.len() >= Self::MAX_BATCH_SIZE {
                        // Swap in an empty buffer so the full one can be moved into the request.
                        let full = std::mem::replace(
                            &mut pending,
                            Vec::with_capacity(Self::MAX_BATCH_SIZE),
                        );
                        self.flush_batch(full, &mut stats).await?;
                    }
                }
            }

            if page.is_done() {
                break;
            }
            if let Some(next_url) = page.next_records_url {
                page = self.client.rest().query_more(&next_url).await?;
            } else {
                // Not marked done but no cursor to follow: nothing more can be fetched.
                break;
            }
        }

        // Submit the partially filled last batch (no-op when nothing is pending).
        self.flush_batch(pending, &mut stats).await?;
        Ok(stats)
    }

    /// Sends `ops` as one composite batch request and tallies the outcome into `stats`.
    ///
    /// A sub-result with a 2xx status counts as a success; any other status counts as a
    /// failure. The tallies come from `response.results`, i.e. what the server reported,
    /// not from `ops.len()`. Does nothing when `ops` is empty.
    ///
    /// # Errors
    ///
    /// Propagates errors from adding an operation to the batch or from executing it.
    async fn flush_batch(&self, ops: Vec<BatchOp>, stats: &mut BatchStats) -> Result<()> {
        if ops.is_empty() {
            return Ok(());
        }
        let mut batch = self
            .client
            .composite()
            .batch()
            .halt_on_error(self.halt_on_error);
        for op in ops {
            batch = match op {
                BatchOp::Update(sobject, id, fields) => batch.patch(&sobject, &id, fields)?,
                BatchOp::Delete(sobject, id) => batch.delete(&sobject, &id)?,
                BatchOp::Create(sobject, fields) => batch.post(&sobject, fields)?,
            };
        }
        let response = batch.execute().await?;
        for res in response.results {
            if (200..=299).contains(&res.status_code) {
                stats.ops_succeeded += 1;
            } else {
                stats.ops_failed += 1;
            }
        }
        Ok(())
    }
}
// `QueryBatch` and the `ForceClient` import only exist with the `composite` feature, so these
// tests must be gated on it too; a bare `#[cfg(test)]` fails to compile without that feature.
#[cfg(all(test, feature = "composite"))]
mod tests {
    use super::*;
    use crate::client::builder;
    use crate::test_support::{MockAuthenticator, Must};
    use serde_json::json;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Starts a fresh mock HTTP server for a single test.
    async fn create_mock_server() -> MockServer {
        MockServer::start().await
    }

    /// Builds a client authenticated against `mock_server`.
    async fn create_test_client(mock_server: &MockServer) -> ForceClient<MockAuthenticator> {
        let auth = MockAuthenticator::new("test_token", &mock_server.uri());
        builder().authenticate(auth).build().await.must()
    }

    /// `count` successful (204, empty body) composite-batch sub-results.
    fn success_results(count: usize) -> Vec<serde_json::Value> {
        (0..count)
            .map(|_| json!({ "statusCode": 204, "result": null }))
            .collect()
    }

    #[tokio::test]
    async fn test_query_batch_pagination() {
        let mock_server = create_mock_server().await;
        let client = create_test_client(&mock_server).await;

        let account = |i: usize| {
            json!({
                "attributes": {
                    "type": "Account",
                    "url": format!("/services/data/v60.0/sobjects/Account/0010000000000{:02}AAA", i)
                },
                "Id": format!("0010000000000{:02}AAA", i)
            })
        };
        let page1_records: Vec<_> = (0..20).map(account).collect();
        let page2_records: Vec<_> = (20..30).map(account).collect();

        // First page: 20 records plus a cursor to the second page.
        Mock::given(method("GET"))
            .and(path("/services/data/v60.0/query"))
            .and(query_param("q", "SELECT Id FROM Account"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "totalSize": 30,
                "done": false,
                "nextRecordsUrl": "/services/data/v60.0/query/01gD0000002HU6K",
                "records": page1_records
            })))
            .mount(&mock_server)
            .await;
        // Second and last page: 10 records.
        Mock::given(method("GET"))
            .and(path("/services/data/v60.0/query/01gD0000002HU6K"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "totalSize": 30,
                "done": true,
                "records": page2_records
            })))
            .mount(&mock_server)
            .await;

        // 30 records -> one full batch of 25 (20 from page 1 + 5 from page 2), then a trailing
        // batch of 5. wiremock answers with the earliest-mounted matching mock that has not used
        // up its `up_to_n_times` budget, so the 25-result response must be mounted first.
        let full_batch_results = success_results(25);
        let tail_batch_results = success_results(5);
        Mock::given(method("POST"))
            .and(path("/services/data/v60.0/composite/batch"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "hasErrors": false,
                "results": full_batch_results
            })))
            .up_to_n_times(1)
            .expect(1)
            .mount(&mock_server)
            .await;
        Mock::given(method("POST"))
            .and(path("/services/data/v60.0/composite/batch"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "hasErrors": false,
                "results": tail_batch_results
            })))
            .up_to_n_times(1)
            .expect(1)
            .mount(&mock_server)
            .await;

        let query = "SELECT Id FROM Account";
        let query_batch = QueryBatch::new(&client, query);
        let stats = query_batch
            .run::<crate::types::DynamicSObject, _>(|record| {
                let id = record.get_field_as::<String>("Id").ok().flatten().must();
                Some(BatchOp::Delete(record.object_type().to_string(), id))
            })
            .await
            .must();

        assert_eq!(stats.records_processed, 30);
        assert_eq!(stats.ops_succeeded, 30);
        assert_eq!(stats.ops_failed, 0);
    }

    #[tokio::test]
    async fn test_query_batch_partial_failures() {
        let mock_server = create_mock_server().await;
        let client = create_test_client(&mock_server).await;

        let records: Vec<_> = (0..10usize)
            .map(|i| {
                json!({
                    "attributes": {
                        "type": "Contact",
                        "url": format!("/services/data/v60.0/sobjects/Contact/0030000000000{:02}AAA", i)
                    },
                    "Id": format!("0030000000000{:02}AAA", i)
                })
            })
            .collect();
        Mock::given(method("GET"))
            .and(path("/services/data/v60.0/query"))
            .and(query_param("q", "SELECT Id FROM Contact"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "totalSize": 10,
                "done": true,
                "records": records
            })))
            .mount(&mock_server)
            .await;

        // Even-indexed sub-requests succeed, odd-indexed ones fail with a 400.
        let results: Vec<_> = (0..10usize)
            .map(|i| {
                if i % 2 == 0 {
                    json!({ "statusCode": 204, "result": null })
                } else {
                    json!({ "statusCode": 400, "result": [{"message": "Bad Request", "errorCode": "INVALID_FIELD"}] })
                }
            })
            .collect();
        // 10 operations fit in a single composite batch.
        Mock::given(method("POST"))
            .and(path("/services/data/v60.0/composite/batch"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "hasErrors": true,
                "results": results
            })))
            .expect(1)
            .mount(&mock_server)
            .await;

        let query = "SELECT Id FROM Contact";
        let query_batch = QueryBatch::new(&client, query);
        let stats = query_batch
            .run::<crate::types::DynamicSObject, _>(|record| {
                let id = record.get_field_as::<String>("Id").ok().flatten().must();
                Some(BatchOp::Delete(record.object_type().to_string(), id))
            })
            .await
            .must();

        assert_eq!(stats.records_processed, 10);
        assert_eq!(stats.ops_succeeded, 5);
        assert_eq!(stats.ops_failed, 5);
    }
}