use std::{collections::HashMap, fs::File, io::Write, time::Duration};
use reqwest::header::{HeaderMap, HeaderValue};
use sindri_openapi::{
apis::{
circuit_download, circuit_status,
circuits_api::{circuit_delete, circuit_detail, CircuitDetailError},
configuration::Configuration,
proof_status,
proofs_api::{proof_delete, proof_detail, ProofDetailError},
Error,
},
models::{CircuitInfoResponse, JobStatus, ProofInfoResponse},
};
use tracing::{debug, info, warn};
use crate::{
custom_middleware::{
retry_client, HeaderDeduplicatorMiddleware, LoggingMiddleware,
ZstdRequestCompressionMiddleware,
},
types::{CircuitInfo, ProofInput},
};
#[cfg(any(feature = "record", feature = "replay"))]
use crate::custom_middleware::vcr_middleware;
#[cfg(feature = "rich-terminal")]
use crate::utils::ClockProgressBar;
/// Credentials and endpoint used when constructing a [`SindriClient`].
///
/// Any field left as `None` is resolved by `SindriClient::new` from the environment
/// (`SINDRI_API_KEY` / `SINDRI_BASE_URL`), and the base URL finally falls back to
/// `https://sindri.app`.
///
/// `Debug` is implemented by hand so that the API key is never printed verbatim
/// (e.g. when the options end up in a log line or panic message via `{:?}`).
#[derive(Default, Clone)]
pub struct AuthOptions {
    /// API key sent as the bearer token on every request.
    pub api_key: Option<String>,
    /// Base URL of the Sindri API.
    pub base_url: Option<String>,
}

impl std::fmt::Debug for AuthOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthOptions")
            // Only reveal whether a key is present, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .finish()
    }
}
/// Controls how long-running jobs (circuit compilation, proof generation) are polled.
#[derive(Clone, Debug)]
pub struct PollingOptions {
    /// Delay between two consecutive status requests.
    pub interval: Duration,
    /// Upper bound on the total polling time; `None` means poll until the job finishes.
    pub timeout: Option<Duration>,
}

impl Default for PollingOptions {
    /// Polls once per second and gives up after ten minutes.
    fn default() -> Self {
        let ten_minutes = Duration::from_secs(10 * 60);
        PollingOptions {
            interval: Duration::from_secs(1),
            timeout: Some(ten_minutes),
        }
    }
}
/// Client for the Sindri API.
///
/// Holds the generated OpenAPI `Configuration` (base URL, bearer token, HTTP client)
/// together with the [`PollingOptions`] used by the job-waiting helpers.
pub struct SindriClient {
    pub(crate) config: Configuration,
    pub polling_options: PollingOptions,
}

impl std::fmt::Debug for SindriClient {
    // Written by hand instead of derived: a derived impl would print the whole
    // `Configuration`, including `bearer_access_token`, exposing the API key in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SindriClient")
            .field("base_url", &self.config.base_path)
            .field(
                "api_key",
                &self.config.bearer_access_token.as_ref().map(|_| "<redacted>"),
            )
            .field("polling_options", &self.polling_options)
            .finish()
    }
}

impl Default for SindriClient {
    /// Equivalent to `SindriClient::new(None, None)`: credentials and base URL come from
    /// the environment (or built-in defaults) and the default polling options are used.
    fn default() -> Self {
        Self::new(None, None)
    }
}
impl SindriClient {
    /// Builds a client from optional authentication and polling settings.
    ///
    /// Resolution order:
    /// - API key: `auth_options.api_key`, then the `SINDRI_API_KEY` environment variable,
    ///   otherwise no key.
    /// - Base URL: `auth_options.base_url`, then `SINDRI_BASE_URL`, then `https://sindri.app`.
    /// - Polling: `polling_options`, otherwise [`PollingOptions::default`].
    ///
    /// Every request carries a `Sindri-Client: <crate name>/v<crate version>` header and is
    /// routed through the header-deduplication, logging, retry and zstd request-compression
    /// middleware. With the `record` or `replay` feature enabled (and outside `cfg(test)`),
    /// a VCR middleware is attached as well, using the bundle path from `VCR_PATH`
    /// (default `tests/recordings/replay.vcr.json`).
    ///
    /// # Panics
    ///
    /// Panics if the default header value or the underlying `reqwest` client cannot be
    /// built, or, with the `replay` feature, if the recording bundle does not exist.
    pub fn new(auth_options: Option<AuthOptions>, polling_options: Option<PollingOptions>) -> Self {
        let mut headers = HeaderMap::new();
        headers.insert(
            "Sindri-Client",
            HeaderValue::from_str(
                format!("{}/v{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")).as_str(),
            )
            .expect("Could not insert default rust client header"),
        );

        // `mut` is only exercised when the `record`/`replay` features reassign the builder
        // below; without those features the binding is never mutated.
        #[allow(unused_mut)]
        let mut client_builder = reqwest_middleware::ClientBuilder::new(
            reqwest::Client::builder()
                .default_headers(headers)
                .zstd(true)
                .build()
                .expect("Could not build client"),
        )
        .with(HeaderDeduplicatorMiddleware)
        .with(LoggingMiddleware)
        .with(retry_client(None))
        .with(ZstdRequestCompressionMiddleware);

        #[cfg(any(feature = "record", feature = "replay"))]
        {
            // The VCR middleware is not attached in this crate's own unit-test build.
            if !cfg!(test) {
                let bundle = std::env::var("VCR_PATH")
                    .unwrap_or_else(|_| "tests/recordings/replay.vcr.json".to_string());
                let bundle_path = std::path::PathBuf::from(&bundle);
                // Replaying requires an existing recording; recording may create it.
                #[cfg(feature = "replay")]
                if !bundle_path.exists() {
                    panic!("Recording not found at: {}", bundle_path.display());
                }
                client_builder = client_builder.with(vcr_middleware(bundle_path));
            }
        }
        let client = client_builder.build();

        let auth = auth_options.unwrap_or_default();
        let base_url = auth
            .base_url
            .or_else(|| std::env::var("SINDRI_BASE_URL").ok())
            .unwrap_or_else(|| "https://sindri.app".to_string());
        let api_key = auth
            .api_key
            .or_else(|| std::env::var("SINDRI_API_KEY").ok());

        let config = Configuration {
            base_path: base_url,
            bearer_access_token: api_key,
            client,
            ..Default::default()
        };
        Self {
            config,
            polling_options: polling_options.unwrap_or_default(),
        }
    }

    /// Returns the API key used as the bearer token, if one is configured.
    pub fn api_key(&self) -> Option<&str> {
        self.config.bearer_access_token.as_deref()
    }

    /// Returns the base URL all API requests are sent to.
    pub fn base_url(&self) -> &str {
        &self.config.base_path
    }

    /// Replaces the API key (bearer token) and returns the updated client.
    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.config.bearer_access_token = Some(api_key.into());
        self
    }

    /// Replaces the base URL and returns the updated client.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.config.base_path = base_url.into();
        self
    }

    /// Sets the delay between consecutive status polls.
    pub fn with_polling_interval(mut self, interval: Duration) -> Self {
        self.polling_options.interval = interval;
        self
    }

    /// Sets the maximum total time the job helpers wait for a terminal status.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.polling_options.timeout = Some(timeout);
        self
    }

    /// Removes the polling timeout so the job helpers wait until the job finishes.
    pub fn with_no_timeout(mut self) -> Self {
        self.polling_options.timeout = None;
        self
    }

    /// Uploads a circuit project and waits for compilation to reach a terminal state.
    ///
    /// After submission, `circuit_status` is polled every `polling_options.interval` until
    /// the job is `Ready` or `Failed`; the circuit detail is then fetched and returned.
    /// A `Failed` compilation is still returned as `Ok` (a warning is logged), so callers
    /// should inspect the status of the returned response.
    ///
    /// # Errors
    ///
    /// Returns an error if the build request, any status request or the final detail request
    /// fails, or if `polling_options.timeout` elapses before a terminal status is reached.
    pub async fn create_circuit(
        &self,
        project: String,
        tags: Option<Vec<String>>,
        meta: Option<HashMap<String, String>>,
    ) -> Result<CircuitInfoResponse, Box<dyn std::error::Error>> {
        let response = self.request_build(project, tags, meta).await?;
        let circuit_id = response.id();
        info!("Circuit created with ID: {}", circuit_id);

        #[cfg(feature = "rich-terminal")]
        let mut current_status = *response.status();
        #[cfg(feature = "rich-terminal")]
        let pb = ClockProgressBar::new(&format!("Job status: {}", current_status));

        let start_time = std::time::Instant::now();
        let mut status = circuit_status(&self.config, circuit_id).await?;
        debug!("Initial circuit status: {:?}", status.status);
        while !matches!(status.status, JobStatus::Ready | JobStatus::Failed) {
            if let Some(timeout) = self.polling_options.timeout {
                if start_time.elapsed() > timeout {
                    warn!("Circuit compilation timed out after {:?}", timeout);
                    return Err(
                        "Circuit compilation did not complete within timeout duration".into(),
                    );
                }
            }
            // Async timer: `std::thread::sleep` would block the executor thread (and every
            // other task scheduled on it) for the whole polling interval.
            tokio::time::sleep(self.polling_options.interval).await;
            status = circuit_status(&self.config, circuit_id).await?;
            #[cfg(feature = "rich-terminal")]
            if status.status != current_status {
                pb.update_message(&format!("Job status: {}", status.status));
                current_status = status.status;
            }
        }

        // The loop above only exits on one of these two states.
        match status.status {
            JobStatus::Ready => info!(
                "Circuit compilation completed successfully after {:?}",
                start_time.elapsed()
            ),
            JobStatus::Failed => warn!(
                "Circuit compilation failed after {:?}",
                start_time.elapsed()
            ),
            _ => unreachable!(),
        }
        let circuit_info = circuit_detail(&self.config, circuit_id, None).await?;
        Ok(circuit_info)
    }

    /// Blocking wrapper around [`SindriClient::create_circuit`].
    ///
    /// Spins up a fresh Tokio runtime for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (for example inside
    /// another Tokio runtime), because `Runtime::block_on` cannot be nested.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime cannot be created, or any error from `create_circuit`.
    pub fn create_circuit_blocking(
        &self,
        project: String,
        tags: Option<Vec<String>>,
        meta: Option<HashMap<String, String>>,
    ) -> Result<CircuitInfoResponse, Box<dyn std::error::Error>> {
        let runtime = tokio::runtime::Runtime::new()?;
        runtime.block_on(self.create_circuit(project, tags, meta))
    }

    /// Deletes the circuit with the given ID.
    ///
    /// # Errors
    ///
    /// Returns an error if the delete request fails.
    pub async fn delete_circuit(&self, circuit_id: &str) -> Result<(), Box<dyn std::error::Error>> {
        info!("Deleting circuit with ID: {}", circuit_id);
        circuit_delete(&self.config, circuit_id).await?;
        Ok(())
    }

    /// Downloads the circuit's project archive and writes it to `download_path`.
    ///
    /// The whole response body is read into memory *before* the destination file is
    /// created, so a failed download never truncates an existing file at that path.
    ///
    /// # Errors
    ///
    /// Returns an error if the download request or body read fails, or if the file cannot
    /// be created or written.
    pub async fn clone_circuit(
        &self,
        circuit_id: &str,
        download_path: String,
    ) -> Result<(), Box<dyn std::error::Error>> {
        info!("Cloning circuit with ID: {}", circuit_id);
        debug!("Download path: {}", download_path);
        let download_response = circuit_download(&self.config, circuit_id, None).await?;
        debug!("Circuit downloaded successfully");
        let contents = download_response.bytes().await?;
        let mut file = File::create(&download_path)?;
        file.write_all(&contents)?;
        info!("Circuit written to {}", download_path);
        Ok(())
    }

    /// Blocking wrapper around [`SindriClient::clone_circuit`].
    ///
    /// Spins up a fresh Tokio runtime for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (for example inside
    /// another Tokio runtime), because `Runtime::block_on` cannot be nested.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime cannot be created, or any error from `clone_circuit`.
    pub fn clone_circuit_blocking(
        &self,
        circuit_id: &str,
        download_path: String,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let runtime = tokio::runtime::Runtime::new()?;
        runtime.block_on(self.clone_circuit(circuit_id, download_path))
    }

    /// Fetches details for the circuit with the given ID.
    ///
    /// `include_verification_key` is forwarded unchanged to the circuit-detail endpoint.
    ///
    /// # Errors
    ///
    /// Returns the generated API error if the detail request fails.
    pub async fn get_circuit(
        &self,
        circuit_id: &str,
        include_verification_key: Option<bool>,
    ) -> Result<CircuitInfoResponse, Error<CircuitDetailError>> {
        info!("Getting circuit with ID: {}", circuit_id);
        let circuit_info =
            circuit_detail(&self.config, circuit_id, include_verification_key).await?;
        Ok(circuit_info)
    }

    /// Requests a proof for `circuit_id` and waits for generation to reach a terminal state.
    ///
    /// After submission, `proof_status` is polled every `polling_options.interval` until
    /// the job is `Ready` or `Failed`; the proof detail is then fetched and returned.
    /// A `Failed` proof is still returned as `Ok` (a warning is logged), so callers should
    /// inspect the status of the returned response.
    ///
    /// # Errors
    ///
    /// Returns an error if the proof request, any status request or the final detail request
    /// fails, or if `polling_options.timeout` elapses before a terminal status is reached.
    pub async fn prove_circuit(
        &self,
        circuit_id: &str,
        proof_input: impl Into<ProofInput>,
        meta: Option<HashMap<String, String>>,
        verify: Option<bool>,
        prover_implementation: Option<String>,
    ) -> Result<ProofInfoResponse, Box<dyn std::error::Error>> {
        let proof_info = self
            .request_proof(circuit_id, proof_input, meta, verify, prover_implementation)
            .await?;
        let proof_id = proof_info.proof_id;
        info!("Proof generation started with ID: {}", proof_id);

        // Start timing before the first status request, as `create_circuit` does.
        let start_time = std::time::Instant::now();
        let mut status = proof_status(&self.config, &proof_id).await?;
        debug!("Initial proof status: {:?}", status.status);
        while !matches!(status.status, JobStatus::Ready | JobStatus::Failed) {
            if let Some(timeout) = self.polling_options.timeout {
                if start_time.elapsed() > timeout {
                    warn!("Proof generation timed out after {:?}", timeout);
                    return Err("Proof generation did not complete within timeout duration".into());
                }
            }
            // Async timer: `std::thread::sleep` would block the executor thread (and every
            // other task scheduled on it) for the whole polling interval.
            tokio::time::sleep(self.polling_options.interval).await;
            status = proof_status(&self.config, &proof_id).await?;
        }

        // The loop above only exits on one of these two states.
        match status.status {
            JobStatus::Ready => info!(
                "Proof generation completed successfully after {:?}",
                start_time.elapsed()
            ),
            JobStatus::Failed => warn!("Proof generation failed after {:?}", start_time.elapsed()),
            _ => unreachable!(),
        }
        let proof_info = proof_detail(&self.config, &proof_id, None, None, None, None).await?;
        Ok(proof_info)
    }

    /// Blocking wrapper around [`SindriClient::prove_circuit`].
    ///
    /// Spins up a fresh Tokio runtime for the duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if called from within an asynchronous execution context (for example inside
    /// another Tokio runtime), because `Runtime::block_on` cannot be nested.
    ///
    /// # Errors
    ///
    /// Returns an error if the runtime cannot be created, or any error from `prove_circuit`.
    pub fn prove_circuit_blocking(
        &self,
        circuit_id: &str,
        proof_input: impl Into<ProofInput>,
        meta: Option<HashMap<String, String>>,
        verify: Option<bool>,
        prover_implementation: Option<String>,
    ) -> Result<ProofInfoResponse, Box<dyn std::error::Error>> {
        let runtime = tokio::runtime::Runtime::new()?;
        runtime.block_on(self.prove_circuit(
            circuit_id,
            proof_input,
            meta,
            verify,
            prover_implementation,
        ))
    }

    /// Deletes the proof with the given ID.
    ///
    /// # Errors
    ///
    /// Returns an error if the delete request fails.
    pub async fn delete_proof(&self, proof_id: &str) -> Result<(), Box<dyn std::error::Error>> {
        info!("Deleting proof with ID: {}", proof_id);
        proof_delete(&self.config, proof_id).await?;
        Ok(())
    }

    /// Fetches details for the proof with the given ID.
    ///
    /// `include_proof`, `include_public` and `include_verification_key` are forwarded to
    /// the proof-detail endpoint.
    ///
    /// # Errors
    ///
    /// Returns the generated API error if the detail request fails.
    pub async fn get_proof(
        &self,
        proof_id: &str,
        include_proof: Option<bool>,
        include_public: Option<bool>,
        include_verification_key: Option<bool>,
    ) -> Result<ProofInfoResponse, Error<ProofDetailError>> {
        info!("Getting proof with ID: {}", proof_id);
        let proof_info = proof_detail(
            &self.config,
            proof_id,
            include_proof,
            include_public,
            // NOTE(review): this positional argument is not exposed by this wrapper and is
            // always `None`; confirm its meaning against `proofs_api::proof_detail`.
            None,
            include_verification_key,
        )
        .await?;
        Ok(proof_info)
    }
}
// Unit tests for client construction, configuration precedence, default headers,
// tag validation and request/response logging.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BoojumCircuitInfoResponse;
    use tracing_test::traced_test;
    use wiremock::{
        matchers::{header_exists, method, path},
        Mock, MockServer, ResponseTemplate,
    };

    // Explicit `AuthOptions` values are used verbatim by the constructor.
    #[test]
    fn test_new_client_with_options() {
        let auth_options = AuthOptions {
            api_key: Some("test_key".to_string()),
            base_url: Some("https://fake.sindri.app".to_string()),
        };
        let client = SindriClient::new(Some(auth_options), None);
        assert_eq!(client.api_key(), Some("test_key"));
        assert_eq!(client.base_url(), "https://fake.sindri.app");
    }

    // The `with_*` builder methods overwrite credentials and polling settings.
    #[test]
    fn test_builder_methods() {
        let client = SindriClient::default()
            .with_api_key("test_key")
            .with_base_url("https://example.com")
            .with_polling_interval(Duration::from_secs(5))
            .with_timeout(Duration::from_secs(300));
        assert_eq!(client.api_key(), Some("test_key"));
        assert_eq!(client.base_url(), "https://example.com");
        assert_eq!(client.polling_options.interval, Duration::from_secs(5));
        assert_eq!(
            client.polling_options.timeout,
            Some(Duration::from_secs(300))
        );
    }

    // `with_no_timeout` clears a previously configured timeout.
    #[test]
    fn test_with_no_timeout() {
        let client = SindriClient::default()
            .with_timeout(Duration::from_secs(300))
            .with_no_timeout();
        assert_eq!(client.polling_options.timeout, None);
    }

    // Without `AuthOptions`, the key and base URL are read from the environment.
    // `temp_env::with_vars` scopes the variables to the closure.
    #[test]
    fn test_new_client_with_env_vars() {
        temp_env::with_vars(
            vec![
                ("SINDRI_API_KEY", Some("env_test_key")),
                ("SINDRI_BASE_URL", Some("https://example.com")),
            ],
            || {
                let client = SindriClient::new(None, None);
                assert_eq!(client.api_key(), Some("env_test_key"));
                assert_eq!(client.base_url(), "https://example.com");
            },
        );
    }

    // Explicit `AuthOptions` take precedence over the environment variables.
    #[test]
    fn test_auth_options_override_env_vars() {
        temp_env::with_vars(
            vec![
                ("SINDRI_API_KEY", Some("env_test_key")),
                ("SINDRI_BASE_URL", Some("https://example.com")),
            ],
            || {
                let auth_options = AuthOptions {
                    api_key: Some("test_key".to_string()),
                    base_url: Some("https://other.example.com".to_string()),
                };
                let client = SindriClient::new(Some(auth_options), None);
                assert_eq!(client.api_key(), Some("test_key"));
                assert_eq!(client.base_url(), "https://other.example.com");
            },
        );
    }

    // With neither options nor environment variables, there is no key and the base URL
    // falls back to the production default.
    #[test]
    fn test_new_client_auth_defaults() {
        temp_env::with_vars(
            vec![
                ("SINDRI_API_KEY", None::<String>),
                ("SINDRI_BASE_URL", None::<String>),
            ],
            || {
                let client = SindriClient::new(None, None);
                assert_eq!(client.api_key(), None);
                assert_eq!(client.base_url(), "https://sindri.app");
            },
        );
    }

    // Only bearer auth is configured; the other generated auth mechanisms stay unset.
    #[test]
    fn test_new_client_config_defaults() {
        let client = SindriClient::new(None, None);
        let config = client.config;
        assert_eq!(config.basic_auth, None);
        assert_eq!(config.oauth_access_token, None);
        assert!(config.api_key.is_none());
    }

    // Custom `PollingOptions` passed to the constructor are stored unchanged.
    #[test]
    fn test_polling_options_custom() {
        let polling_options = PollingOptions {
            interval: Duration::from_secs(5),
            timeout: Some(Duration::from_secs(300)), };
        let client = SindriClient::new(None, Some(polling_options));
        assert_eq!(client.polling_options.interval, Duration::from_secs(5));
        assert_eq!(
            client.polling_options.timeout,
            Some(Duration::from_secs(300))
        );
    }

    // Default polling is 1s / 600s, and the public fields can be changed after construction.
    #[test]
    fn test_post_client_init_polling_tweaks() {
        let mut client = SindriClient::new(None, None);
        assert_eq!(client.polling_options.interval, Duration::from_secs(1));
        assert_eq!(
            client.polling_options.timeout,
            Some(Duration::from_secs(600))
        );
        client.polling_options.interval = Duration::from_secs(5);
        client.polling_options.timeout = Some(Duration::from_secs(300));
        assert_eq!(client.polling_options.interval, Duration::from_secs(5));
        assert_eq!(
            client.polling_options.timeout,
            Some(Duration::from_secs(300))
        );
    }

    // The mock only answers 200 when a `sindri-client` header is present, so a 200 response
    // shows the default header is attached to every request.
    #[tokio::test]
    async fn test_client_default_header() {
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(header_exists("sindri-client"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;
        let outer_client = SindriClient::new(None, None);
        let inner_client = &outer_client.config.client;
        let request = inner_client.get(mock_server.uri()).build().unwrap();
        let response = inner_client.execute(request).await.unwrap();
        assert_eq!(response.status(), 200);
    }

    // Invalid tags are rejected, and the error message names the offending tag.
    // NOTE(review): presumably validation in `request_build` happens before any network
    // I/O, since this client points at the default base URL; confirm.
    #[tokio::test]
    async fn test_circuit_create_tag_validation() {
        let client = SindriClient::new(None, None);
        let mut tags = vec!["test_t@g".to_string()];
        let mut circuit = client
            .create_circuit("fake_path".to_string(), Some(tags), None)
            .await;
        assert!(circuit.is_err());
        assert!(circuit.unwrap_err().to_string().contains("not a valid tag"));
        tags = vec![
            "test_tag".to_string(),
            "1-2-3-4-5-6".to_string(),
            "ಠ_ಠ".to_string(),
        ];
        circuit = client
            .create_circuit("fake_path".to_string(), Some(tags), None)
            .await;
        assert!(circuit.is_err());
        assert!(circuit.unwrap_err().to_string().contains("ಠ_ಠ"));
    }

    // Starts a mock API that accepts a circuit upload (ID `test_circuit_123`), reports it as
    // `Ready` on the first status request, and answers the detail request with an empty 200.
    async fn mock_compile_server() -> MockServer {
        let mock_server = wiremock::MockServer::start().await;
        wiremock::Mock::given(method("POST"))
            .and(path("/api/v1/circuit/create"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(CircuitInfoResponse::Boojum(Box::new(
                    BoojumCircuitInfoResponse {
                        circuit_id: "test_circuit_123".to_string(),
                        ..Default::default()
                    },
                ))),
            )
            .mount(&mock_server)
            .await;
        wiremock::Mock::given(method("GET"))
            .and(path("/api/v1/circuit/test_circuit_123/status"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(CircuitInfoResponse::Boojum(Box::new(
                    BoojumCircuitInfoResponse {
                        status: JobStatus::Ready,
                        ..Default::default()
                    },
                ))),
            )
            .mount(&mock_server)
            .await;
        wiremock::Mock::given(method("GET"))
            .and(path("/api/v1/circuit/test_circuit_123/detail"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;
        mock_server
    }

    // Checks the log lines emitted during a full `create_circuit` run against the mock API.
    // The detail mock has an empty body, so the final result is intentionally not asserted.
    #[tokio::test]
    #[traced_test]
    async fn test_verbose_logging() {
        let mock_server = mock_compile_server().await;
        let auth_options = AuthOptions {
            api_key: Some("test_key".to_string()),
            base_url: Some(mock_server.uri()),
        };
        let client = SindriClient::new(Some(auth_options), None);
        let temp_dir = tempfile::tempdir().unwrap();
        let test_file = temp_dir.path().join("test.zip");
        std::fs::write(&test_file, "test content").unwrap();
        let _result = client
            .create_circuit(test_file.to_str().unwrap().to_string(), None, None)
            .await;
        assert!(logs_contain("Creating new circuit from project"));
        assert!(logs_contain("Uploading circuit to Sindri"));
        assert!(logs_contain("Circuit created with ID: test_circuit_123"));
        assert!(logs_contain("Circuit compilation completed"));
        // Exactly three round trips: create, initial status (already Ready), detail.
        logs_assert(|lines: &[&str]| {
            match lines
                .iter()
                .filter(|line| line.contains("Request sent"))
                .count()
            {
                3 => Ok(()),
                n => Err(format!(
                    "Expected three logs for request outbound, but found {}",
                    n
                )),
            }
        });
        logs_assert(|lines: &[&str]| {
            match lines
                .iter()
                .filter(|line| line.contains("Response received"))
                .count()
            {
                3 => Ok(()),
                n => Err(format!(
                    "Expected three logs for response inbound, but found {}",
                    n
                )),
            }
        });
    }
}