use super::CompositeHandler;
use crate::api::soql::SoqlQueryBuilder;
use crate::auth::Authenticator;
use crate::error::{ForceError, Result};
use crate::types::validator;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Validates a subrequest `referenceId` before it is placed in a graph.
///
/// Delegates to the shared identifier validator under the label
/// `"Reference ID"`, which prefixes the validator's error messages.
/// NOTE(review): judging by `test_validate_reference_id_table`, only ASCII
/// letters, digits and `_` are accepted and the empty string is rejected —
/// the authoritative rule lives in `validator::validate_identifier`.
///
/// # Errors
///
/// Propagates whatever `validator::validate_identifier` returns.
fn validate_reference_id(id: &str) -> Result<()> {
    const LABEL: &str = "Reference ID";
    validator::validate_identifier(id, LABEL)
}
/// Validates a record ID (or an `@{ref.field}` reference) that is spliced
/// into a subrequest URL path such as `sobjects/Account/{id}`.
///
/// Rejects anything that could change which resource the path addresses:
/// - the empty string;
/// - `/`, `\` and `..` (direct path traversal);
/// - `?` and `#` (would start a query string or fragment);
/// - `%` (a percent-encoded `/` or `.` may be decoded by the server into a
///   traversal sequence);
/// - control characters (never valid in an ID, and would corrupt the URL).
///
/// Reference placeholders like `@{ref1.id}` remain valid.
///
/// # Errors
///
/// Returns `ForceError::InvalidInput` with `"ID cannot be empty"` for an
/// empty ID, or `"ID contains invalid path traversal characters: <id>"`
/// otherwise.
fn validate_graph_id(id: &str) -> Result<()> {
    if id.is_empty() {
        return Err(ForceError::InvalidInput("ID cannot be empty".to_string()));
    }
    let has_forbidden_char = id
        .chars()
        .any(|c| matches!(c, '/' | '\\' | '?' | '#' | '%') || c.is_control());
    if has_forbidden_char || id.contains("..") {
        return Err(ForceError::InvalidInput(format!(
            "ID contains invalid path traversal characters: {}",
            id
        )));
    }
    Ok(())
}
#[derive(Debug)]
pub struct CompositeGraphRequest<A: Authenticator> {
handler: CompositeHandler<A>,
graphs: Vec<Graph>,
}
impl<A: Authenticator> CompositeGraphRequest<A> {
pub(crate) fn new(handler: CompositeHandler<A>) -> Self {
Self {
handler,
graphs: Vec::with_capacity(15),
}
}
pub fn add_graph(mut self, graph: Graph) -> Result<Self> {
let current_subrequests: usize =
self.graphs.iter().map(|g| g.composite_request.len()).sum();
if current_subrequests + graph.composite_request.len() > 500 {
return Err(ForceError::InvalidInput(
"Composite Graph limit of 500 total subrequests exceeded".to_string(),
));
}
self.graphs.push(graph);
Ok(self)
}
pub async fn execute(self) -> Result<GraphResponse> {
if self.graphs.is_empty() {
return Err(ForceError::Serialization(
crate::error::SerializationError::InvalidFormat(
"Graph request cannot be empty".to_string(),
),
));
}
let url = self.handler.inner.resolve_url("composite/graph").await?;
let request_body = GraphRequestBody {
graphs: self.graphs,
};
let request = self
.handler
.inner
.post(&url)
.json(&request_body)
.build()
.map_err(crate::error::HttpError::from)?;
self.handler
.inner
.send_request_and_decode(request, "Composite Graph failed")
.await
}
}
/// One graph of a Composite Graph request: a `graphId` plus its ordered
/// subrequests (nodes).
///
/// Every builder method consumes `self` and returns `Result<Self>`, so a
/// graph is assembled with `?`-chained calls. A validation failure aborts the
/// chain with the validator's error.
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Graph {
    /// Identifier of this graph, serialized as `graphId`.
    pub graph_id: String,
    /// Subrequests in insertion order, serialized as `compositeRequest`.
    pub composite_request: Vec<GraphRequest>,
}

impl Graph {
    /// Maximum number of subrequests [`Graph::add_request`] accepts per graph.
    const MAX_REQUESTS: usize = 500;

    /// Creates an empty graph with the given identifier.
    ///
    /// The identifier is stored as-is; it is not validated here.
    #[must_use]
    pub fn new(graph_id: impl Into<String>) -> Self {
        Self {
            graph_id: graph_id.into(),
            composite_request: Vec::with_capacity(15),
        }
    }

    /// Appends a pre-built subrequest.
    ///
    /// # Errors
    ///
    /// Returns `ForceError::InvalidInput` if the graph already holds 500
    /// subrequests, or if another subrequest in this graph already uses the
    /// same `reference_id`.
    pub fn add_request(mut self, request: GraphRequest) -> Result<Self> {
        if self.composite_request.len() >= Self::MAX_REQUESTS {
            return Err(ForceError::InvalidInput(
                "Graph size limit of 500 requests exceeded".to_string(),
            ));
        }
        // Later nodes address earlier results as `@{referenceId.field}`, so a
        // repeated reference ID would make those references ambiguous; fail
        // here instead of at the server. A linear scan is adequate because
        // the graph is capped at MAX_REQUESTS entries.
        let duplicate = self
            .composite_request
            .iter()
            .any(|existing| existing.reference_id == request.reference_id);
        if duplicate {
            return Err(ForceError::InvalidInput(format!(
                "Duplicate reference ID in graph: {}",
                request.reference_id
            )));
        }
        self.composite_request.push(request);
        Ok(self)
    }

    /// Adds a `GET` of record `id` of `sobject`.
    ///
    /// # Errors
    ///
    /// Fails if the sObject name, record ID, reference ID or resulting URL
    /// path is rejected, or if [`Graph::add_request`] fails.
    pub fn get(self, sobject: &str, id: &str, reference_id: &str) -> Result<Self> {
        validator::validate_sobject_name(sobject)?;
        validate_graph_id(id)?;
        validate_reference_id(reference_id)?;
        self.add_request(GraphRequest::new(
            "GET",
            crate::api::path_utils::format_sobject_path(sobject, Some(id)),
            reference_id,
        )?)
    }

    /// Adds a `POST` creating a new `sobject` record from `body`.
    ///
    /// # Errors
    ///
    /// Fails if the sObject name, reference ID or resulting URL path is
    /// rejected, or if [`Graph::add_request`] fails.
    pub fn post(self, sobject: &str, body: Value, reference_id: &str) -> Result<Self> {
        validator::validate_sobject_name(sobject)?;
        validate_reference_id(reference_id)?;
        self.add_request(
            GraphRequest::new(
                "POST",
                crate::api::path_utils::format_sobject_path(sobject, None),
                reference_id,
            )?
            .body(body),
        )
    }

    /// Adds a `PATCH` updating record `id` of `sobject` with `body`.
    ///
    /// # Errors
    ///
    /// Fails if the sObject name, record ID, reference ID or resulting URL
    /// path is rejected, or if [`Graph::add_request`] fails.
    pub fn patch(self, sobject: &str, id: &str, body: Value, reference_id: &str) -> Result<Self> {
        validator::validate_sobject_name(sobject)?;
        validate_graph_id(id)?;
        validate_reference_id(reference_id)?;
        self.add_request(
            GraphRequest::new(
                "PATCH",
                crate::api::path_utils::format_sobject_path(sobject, Some(id)),
                reference_id,
            )?
            .body(body),
        )
    }

    /// Adds a `DELETE` of record `id` of `sobject`.
    ///
    /// # Errors
    ///
    /// Fails if the sObject name, record ID, reference ID or resulting URL
    /// path is rejected, or if [`Graph::add_request`] fails.
    pub fn delete(self, sobject: &str, id: &str, reference_id: &str) -> Result<Self> {
        validator::validate_sobject_name(sobject)?;
        validate_graph_id(id)?;
        validate_reference_id(reference_id)?;
        self.add_request(GraphRequest::new(
            "DELETE",
            crate::api::path_utils::format_sobject_path(sobject, Some(id)),
            reference_id,
        )?)
    }

    /// Adds a `GET` running the SOQL query built by `query_builder`.
    ///
    /// The URL is produced by `encode_soql_query_url`, e.g.
    /// `query?q=SELECT+Id%2C+Name+FROM+Account`.
    ///
    /// # Errors
    ///
    /// Fails if the reference ID is rejected, the query cannot be encoded,
    /// the URL path is rejected, or [`Graph::add_request`] fails.
    // Taken by value to keep the public signature stable for callers that
    // chain builder calls; only a borrow is needed internally.
    #[allow(clippy::needless_pass_by_value)]
    pub fn query(self, query_builder: SoqlQueryBuilder, reference_id: &str) -> Result<Self> {
        validate_reference_id(reference_id)?;
        let url = crate::api::soql::encode_soql_query_url(&query_builder)?;
        self.add_request(GraphRequest::new("GET", url, reference_id)?)
    }
}
/// A single node (subrequest) inside a [`Graph`].
#[derive(Debug, Serialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GraphRequest {
    /// HTTP method, e.g. `"GET"`, `"POST"`, `"PATCH"`, `"DELETE"`.
    pub method: String,
    /// Relative resource path, e.g. `sobjects/Account` or `query?q=...`.
    pub url: String,
    /// Name other nodes use to refer to this node's result; serialized as
    /// `referenceId`.
    pub reference_id: String,
    /// Optional JSON payload; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body: Option<Value>,
}

impl GraphRequest {
    /// Builds a subrequest with no body.
    ///
    /// Only `url` is checked here, via `validator::validate_url_path`;
    /// `method` and `reference_id` are stored without validation.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validator::validate_url_path`.
    pub fn new(
        method: impl Into<String>,
        url: impl Into<String>,
        reference_id: impl Into<String>,
    ) -> Result<Self> {
        let url: String = url.into();
        validator::validate_url_path(&url)?;
        let request = Self {
            method: method.into(),
            url,
            reference_id: reference_id.into(),
            body: None,
        };
        Ok(request)
    }

    /// Returns this subrequest with `body` attached, replacing any earlier body.
    #[must_use]
    pub fn body(self, body: Value) -> Self {
        Self {
            body: Some(body),
            ..self
        }
    }
}
/// Top-level JSON payload POSTed to `composite/graph`, i.e. `{"graphs": [...]}`.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GraphRequestBody {
    /// Every graph queued on the request, in insertion order.
    graphs: Vec<Graph>,
}
/// Decoded reply from `composite/graph`: one [`GraphResult`] per submitted graph.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GraphResponse {
    /// Per-graph outcomes, read from the `graphs` array.
    pub graphs: Vec<GraphResult>,
}
/// Outcome of a single graph within a Composite Graph response.
///
/// Node results may arrive either directly under `compositeResponse` or
/// nested as `graphResponse.compositeResponse`. Use
/// [`GraphResult::sub_responses`] to read them without caring which envelope
/// the server used.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GraphResult {
    /// The `graphId` this result belongs to.
    pub graph_id: String,
    /// Whether the whole graph succeeded (`isSuccessful`).
    pub is_successful: bool,
    /// Flat node results; empty when the field is absent from the payload.
    #[serde(default)]
    pub composite_response: Vec<GraphSubResponse>,
    /// Nested `graphResponse` envelope, when present.
    #[serde(default)]
    pub graph_response: Option<GraphErrorResponse>,
}

impl GraphResult {
    /// Returns the per-node results, whichever envelope carried them.
    ///
    /// Prefers a non-empty nested `graphResponse.compositeResponse`; otherwise
    /// falls back to the flat `compositeResponse` (which may be empty).
    #[must_use]
    pub fn sub_responses(&self) -> &[GraphSubResponse] {
        match &self.graph_response {
            Some(nested) if !nested.composite_response.is_empty() => {
                &nested.composite_response
            }
            _ => &self.composite_response,
        }
    }
}
/// The nested `graphResponse` envelope of a [`GraphResult`].
///
/// NOTE(review): despite the name, Salesforce appears to use this envelope
/// for successful graphs as well — confirm against the API version in use.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GraphErrorResponse {
    /// Node results carried inside the envelope (`compositeResponse`); the
    /// field is required when the envelope is present.
    pub composite_response: Vec<GraphSubResponse>,
}
/// Result of one node (subrequest) of a graph.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GraphSubResponse {
    /// HTTP status the node produced (`httpStatusCode`).
    pub http_status_code: u16,
    /// The `referenceId` of the subrequest this result answers.
    pub reference_id: String,
    /// Raw JSON body of the node's response, if any.
    pub body: Option<Value>,
    /// Raw JSON headers object of the node's response (`httpHeaders`), if any.
    pub http_headers: Option<Value>,
}
// Unit tests for the Composite Graph builder, its validators, JSON
// serialization of the request body, and decoding of a mocked response.
#[cfg(test)]
mod tests {
use super::*;
use crate::client::builder as client_builder;
use crate::test_support::{MockAuthenticator, Must};
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
// Builds an empty graph request against a fixed dummy instance URL; used by
// tests that either never execute or fail before any HTTP call is made.
async fn create_builder() -> CompositeGraphRequest<MockAuthenticator> {
let auth = MockAuthenticator::new("token", "https://test.salesforce.com");
let client = client_builder().authenticate(auth).build().await.must();
client.composite().graph()
}
// 250 + 251 subrequests across two graphs exceeds the 500 total cap, so the
// second add_graph must be rejected.
#[tokio::test]
async fn test_composite_graph_total_limit() {
let mut builder = create_builder().await;
let mut graph1 = Graph::new("graph1");
for i in 0..250 {
graph1 = graph1
.get("Account", "001000000000000AAA", &format!("ref{}", i))
.must();
}
builder = builder.add_graph(graph1).must();
let mut graph2 = Graph::new("graph2");
for i in 0..251 {
graph2 = graph2
.get("Account", "001000000000000AAA", &format!("ref2_{}", i))
.must();
}
let result = builder.add_graph(graph2);
assert!(matches!(result, Err(ForceError::InvalidInput(_))));
}
// Checks camelCase field names in the serialized payload and that an
// `@{ref.field}` placeholder in a body is passed through verbatim.
#[test]
fn test_graph_serialization() {
let mut graph = Graph::new("graph1");
graph = graph
.post("Account", json!({"Name": "RefAccount"}), "refAccount")
.must();
graph = graph
.post(
"Contact",
json!({
"LastName": "Doe",
"AccountId": "@{refAccount.id}"
}),
"refContact",
)
.must();
let req = GraphRequestBody {
graphs: vec![graph],
};
let json = serde_json::to_string(&req).must();
assert!(json.contains("\"graphId\":\"graph1\""));
assert!(json.contains("\"referenceId\":\"refAccount\""));
assert!(json.contains("\"referenceId\":\"refContact\""));
assert!(json.contains("\"AccountId\":\"@{refAccount.id}\""));
}
// End-to-end execute against a wiremock server returning a flat
// `compositeResponse`. NOTE(review): the mocked path assumes the client
// builder defaults to API version v60.0 — confirm if that default changes.
#[tokio::test]
async fn test_graph_execute_success() {
let mock_server = MockServer::start().await;
let auth = MockAuthenticator::new("token", &mock_server.uri());
let client = client_builder().authenticate(auth).build().await.must();
let mut graph = Graph::new("graph1");
graph = graph
.post("Account", json!({"Name": "Test"}), "acc1")
.must();
let builder = client.composite().graph().add_graph(graph).must();
let response_json = json!({
"graphs": [
{
"graphId": "graph1",
"isSuccessful": true,
"compositeResponse": [
{
"body": {
"id": "001...",
"success": true,
"errors": []
},
"httpHeaders": {},
"httpStatusCode": 201,
"referenceId": "acc1"
}
]
}
]
});
Mock::given(method("POST"))
.and(path("/services/data/v60.0/composite/graph"))
.respond_with(ResponseTemplate::new(200).set_body_json(response_json))
.mount(&mock_server)
.await;
let response = builder.execute().await.must();
assert_eq!(response.graphs.len(), 1);
assert!(response.graphs[0].is_successful);
assert_eq!(response.graphs[0].graph_id, "graph1");
assert_eq!(
response.graphs[0].composite_response[0].reference_id,
"acc1"
);
}
// Executing with no graphs must fail locally with a Serialization error.
#[tokio::test]
async fn test_graph_execute_empty() {
let builder = create_builder().await;
let result = builder.execute().await;
let Err(ForceError::Serialization(e)) = result else {
panic!("Expected Serialization error, got {:?}", result);
};
assert!(e.to_string().contains("Graph request cannot be empty"));
}
// A query node is a body-less GET whose URL is the form-encoded SOQL
// (spaces as `+`, reserved characters percent-escaped).
#[test]
fn test_graph_query_encoding() {
let query = SoqlQueryBuilder::new()
.select(&["Id", "Name"])
.from("Account")
.where_eq("Name", "Acme & Co.");
let graph = Graph::new("graph1").query(query, "refQuery").must();
assert_eq!(graph.composite_request.len(), 1);
let req = &graph.composite_request[0];
assert_eq!(req.method, "GET");
assert_eq!(req.reference_id, "refQuery");
assert!(req.body.is_none());
let expected_url =
"query?q=SELECT+Id%2C+Name+FROM+Account+WHERE+Name+%3D+%27Acme+%26+Co.%27";
assert_eq!(req.url, expected_url);
}
#[test]
fn test_validate_reference_id() {
assert!(validate_reference_id("valid_id_123").is_ok());
assert!(validate_reference_id("validId").is_ok());
assert!(validate_reference_id("").is_err());
assert!(validate_reference_id("invalid ref id! @#$").is_err());
assert!(validate_reference_id("invalid-ref").is_err());
}
// Record IDs and `@{ref.field}` placeholders pass; empty strings and path
// separators / traversal / query characters are rejected.
#[test]
fn test_validate_graph_id() {
assert!(validate_graph_id("001000000000000").is_ok());
assert!(validate_graph_id("@{ref.id}").is_ok());
assert!(validate_graph_id("validId").is_ok());
assert!(validate_graph_id("").is_err());
assert!(validate_graph_id("some/path").is_err());
assert!(validate_graph_id("..").is_err());
assert!(validate_graph_id("path\\test").is_err());
assert!(validate_graph_id("path?query").is_err());
}
// Each convenience builder records the expected method and reference ID,
// and requests are kept in insertion order.
#[test]
fn test_graph_post_patch_delete() {
let mut graph = Graph::new("graph1");
graph = graph
.post("Account", json!({"Name": "Test"}), "refPost")
.must();
graph = graph
.patch(
"Account",
"001000000000000AAA",
json!({"Name": "Updated"}),
"refPatch",
)
.must();
graph = graph
.delete("Account", "001000000000000AAA", "refDelete")
.must();
assert_eq!(graph.composite_request.len(), 3);
let post_req = &graph.composite_request[0];
assert_eq!(post_req.method, "POST");
assert_eq!(post_req.reference_id, "refPost");
let patch_req = &graph.composite_request[1];
assert_eq!(patch_req.method, "PATCH");
assert_eq!(patch_req.reference_id, "refPatch");
let delete_req = &graph.composite_request[2];
assert_eq!(delete_req.method, "DELETE");
assert_eq!(delete_req.reference_id, "refDelete");
}
// Exactly 500 requests are accepted; the 501st is rejected.
#[test]
fn test_graph_size_limit() {
let mut graph = Graph::new("graph1");
for i in 0..500 {
graph = graph
.get("Account", "001000000000000AAA", &format!("ref{}", i))
.must();
}
let result = graph.get("Account", "001000000000000AAA", "ref501");
assert!(
matches!(result, Err(ForceError::InvalidInput(ref msg)) if msg.contains("limit of 500"))
);
}
// A traversal-style record ID must be rejected before a request is built.
#[test]
fn test_havoc_path_traversal() {
let graph = Graph::new("graph1");
let result = graph.get("Account", "../../../../../etc/passwd", "ref1");
let Err(crate::error::ForceError::InvalidInput(msg)) = result else {
panic!(
"Expected InvalidInput error for path traversal, got: {:?}",
result
);
};
assert!(msg.contains("invalid path traversal characters"));
}
// An invalid reference ID must be rejected by the builder methods.
#[test]
fn test_havoc_invalid_reference_id() {
let graph = Graph::new("graph1");
let result = graph.get("Account", "001xx000003DHP0AAO", "invalid ref id! @#$");
let Err(crate::error::ForceError::InvalidInput(msg)) = result else {
panic!(
"Expected InvalidInput error for invalid reference id, got: {:?}",
result
);
};
assert!(msg.contains("Reference ID contains invalid characters"));
}
// Table-driven check of the reference-ID character set and of the exact
// error messages produced for empty vs. malformed IDs.
#[test]
fn test_validate_reference_id_table() {
let valid_ids = vec!["ref1", "Ref_2", "A", "1", "valid_ref_id_123"];
let invalid_ids = vec![
"", "ref-1", "ref 1", "ref!1", "ref@1", "ref#1", "ref$1", "ref%1", "ref^1", "ref&1",
"ref*1", "ref(1", "ref)1", "ref+1", "ref=1", "ref{1", "ref}1", "ref[1", "ref]1",
"ref|1", "ref\\1", "ref:1", "ref;1", "ref\"1", "ref'1", "ref<1", "ref>1", "ref,1",
"ref.1", "ref?1", "ref/1",
];
for id in valid_ids {
assert!(
validate_reference_id(id).is_ok(),
"Expected {} to be valid",
id
);
}
for id in invalid_ids {
let result = validate_reference_id(id);
assert!(result.is_err(), "Expected {} to be invalid", id);
let Err(ForceError::InvalidInput(msg)) = result else {
panic!("Expected InvalidInput error for {}", id);
};
if id.is_empty() {
assert_eq!(msg, "Reference ID cannot be empty");
} else {
assert!(msg.contains("Reference ID contains invalid characters:"));
}
}
}
// GraphRequest::body attaches the JSON payload unchanged.
#[test]
fn test_graph_request_body() {
let req = GraphRequest::new("POST", "sobjects/Account", "ref1")
.must()
.body(serde_json::json!({"Name": "Test"}));
assert_eq!(req.body.must()["Name"], "Test");
}
// Table-driven check of record-ID validation through Graph::get, including
// the exact error messages for empty vs. traversal-style IDs.
#[test]
fn test_validate_graph_id_table() {
let valid_ids = vec!["001000000000000AAA", "@{ref1.id}"];
let invalid_ids = vec![
"",
"../../../etc/passwd",
"../something",
"path\\to",
"path/to",
"id?param=1",
];
for id in valid_ids {
let graph = Graph::new("graph1");
let result = graph.get("Account", id, "ref1");
assert!(result.is_ok(), "Expected {} to be valid graph id", id);
}
for id in invalid_ids {
let graph = Graph::new("graph1");
let result = graph.get("Account", id, "ref1");
assert!(result.is_err(), "Expected {} to be invalid graph id", id);
let Err(ForceError::InvalidInput(msg)) = result else {
panic!("Expected InvalidInput error for {}", id);
};
if id.is_empty() {
assert_eq!(msg, "ID cannot be empty");
} else {
assert!(msg.contains("ID contains invalid path traversal characters:"));
}
}
}
}